diff --git a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py index 7f641e3e3c3f..5cca5680c9de 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py @@ -174,13 +174,7 @@ def _read_stream( "cursor_field": configured_stream.cursor_field, }, ) - logger.debug( - f"Syncing stream instance: {stream_instance.name}", - extra={ - "primary_key": stream_instance.primary_key, - "cursor_field": stream_instance.cursor_field, - }, - ) + stream_instance.log_stream_sync_configuration() use_incremental = configured_stream.sync_mode == SyncMode.incremental and stream_instance.supports_incremental if use_incremental: @@ -294,26 +288,14 @@ def _read_full_refresh( configured_stream: ConfiguredAirbyteStream, internal_config: InternalConfig, ) -> Iterator[AirbyteMessage]: - slices = stream_instance.stream_slices(sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field) - logger.debug( - f"Processing stream slices for {configured_stream.stream.name} (sync_mode: full_refresh)", extra={"stream_slices": slices} - ) total_records_counter = 0 - for _slice in slices: - if self._slice_logger.should_log_slice_message(logger): - yield self._slice_logger.create_slice_log_message(_slice) - record_data_or_messages = stream_instance.read_records( - stream_slice=_slice, - sync_mode=SyncMode.full_refresh, - cursor_field=configured_stream.cursor_field, - ) - for record_data_or_message in record_data_or_messages: - message = self._get_message(record_data_or_message, stream_instance) - yield message - if message.type == MessageType.RECORD: - total_records_counter += 1 - if internal_config.is_limit_reached(total_records_counter): - return + for record_data_or_message in stream_instance.read_full_refresh(configured_stream.cursor_field, logger, self._slice_logger): + message = self._get_message(record_data_or_message, stream_instance) + yield message + if message.type == MessageType.RECORD: + total_records_counter += 1 + if internal_config.is_limit_reached(total_records_counter): + return def _checkpoint_state(self, stream: Stream, stream_state: Mapping[str, Any], state_manager: ConnectorStateManager) -> AirbyteMessage: # First attempt to retrieve the current state using the stream's state property. We receive an AttributeError if the state diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/__init__.py index 0df89f871a52..9326fd1bdca7 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/__init__.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/__init__.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 Airbyte, Inc., all rights reserved. +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. # # Initialize Streams Package diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/__init__.py new file mode 100644 index 000000000000..c941b3045795 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
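Note on the `_read_full_refresh` change above: slicing and slice logging now live on the stream itself (see the new `Stream.read_full_refresh` added to core.py further down in this patch), so the source-level loop only converts stream data to messages and enforces the record limit. A minimal, self-contained sketch of that contract follows; it is not part of this diff, and `stream_data`/`emit_with_limit` are hypothetical stand-ins for `Stream.read_full_refresh` and the `AbstractSource._read_full_refresh` loop.

# Sketch only: read_full_refresh may interleave slice log messages with records,
# and only RECORD messages count toward the record limit.
from typing import Any, Dict, Iterable, Iterator

def stream_data() -> Iterable[Dict[str, Any]]:
    yield {"type": "LOG", "message": 'slice:{"partition": 1}'}  # slice log message
    yield {"type": "RECORD", "data": {"id": 1}}
    yield {"type": "RECORD", "data": {"id": 2}}

def emit_with_limit(messages: Iterable[Dict[str, Any]], limit: int) -> Iterator[Dict[str, Any]]:
    emitted_records = 0
    for message in messages:
        yield message
        if message["type"] == "RECORD":
            emitted_records += 1
            if emitted_records >= limit:
                return

assert [m["type"] for m in emit_with_limit(stream_data(), limit=1)] == ["LOG", "RECORD"]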
+# diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream.py new file mode 100644 index 000000000000..c394cb7621e7 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from abc import ABC, abstractmethod +from typing import Any, Iterable, Mapping, Optional + +from airbyte_cdk.models import AirbyteStream +from airbyte_cdk.sources.streams.concurrent.availability_strategy import StreamAvailability +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from deprecated.classic import deprecated + + +@deprecated("This class is experimental. Use at your own risk.") +class AbstractStream(ABC): + """ + AbstractStream is an experimental interface for streams developed as part of the Concurrent CDK. + This interface is not yet stable and may change in the future. Use at your own risk. + + Why create a new interface instead of adding concurrency capabilities the existing Stream? + We learnt a lot since the initial design of the Stream interface, and we wanted to take the opportunity to improve. + + High level, the changes we are targeting are: + - Removing superfluous or leaky parameters from the methods' interfaces + - Using composition instead of inheritance to add new capabilities + + To allow us to iterate fast while ensuring backwards compatibility, we are creating a new interface with a facade object that will bridge the old and the new interfaces. + Source connectors that wish to leverage concurrency need to implement this new interface. An example will be available shortly + + Current restrictions on sources that implement this interface. Not all of these restrictions will be lifted in the future, but most will as we iterate on the design. + - Only full refresh is supported. This will be addressed in the future. + - The read method does not accept a cursor_field. Streams must be internally aware of the cursor field to use. User-defined cursor fields can be implemented by modifying the connector's main method to instantiate the streams with the configured cursor field. + - Streams cannot return user-friendly messages by overriding Stream.get_error_display_message. This will be addressed in the future. + - The Stream's behavior cannot depend on a namespace + - TypeTransformer is not supported. This will be addressed in the future. + - Nested cursor and primary keys are not supported + """ + + @abstractmethod + def read(self) -> Iterable[Record]: + """ + Read a stream in full refresh mode + :return: The stream's records + """ + + @property + @abstractmethod + def name(self) -> str: + """ + :return: The stream name + """ + + @property + @abstractmethod + def cursor_field(self) -> Optional[str]: + """ + Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. + :return: The name of the field used as a cursor. Nested cursor fields are not supported. + """ + + @abstractmethod + def check_availability(self) -> StreamAvailability: + """ + :return: The stream's availability + """ + + @abstractmethod + def get_json_schema(self) -> Mapping[str, Any]: + """ + :return: A dict of the JSON schema representing this stream. + """ + + @abstractmethod + def as_airbyte_stream(self) -> AirbyteStream: + """ + :return: A dict of the JSON schema representing this stream. 
+ """ + + @abstractmethod + def log_stream_sync_configuration(self) -> None: + """ + Logs the stream's configuration for debugging purposes. + """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py new file mode 100644 index 000000000000..c71bb8f12872 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -0,0 +1,329 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import copy +import json +import logging +from functools import lru_cache +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union + +from airbyte_cdk.models import AirbyteStream, SyncMode +from airbyte_cdk.sources import AbstractSource, Source +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AbstractAvailabilityStrategy, + StreamAvailability, + StreamAvailable, + StreamUnavailable, +) +from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.concurrent.thread_based_concurrent_stream import ThreadBasedConcurrentStream +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.utils.slice_logger import SliceLogger +from deprecated.classic import deprecated + +""" +This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream +""" + + +@deprecated("This class is experimental. Use at your own risk.") +class StreamFacade(Stream): + """ + The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream. + + All methods either delegate to the wrapped AbstractStream or provide a default implementation. + The default implementations define restrictions imposed on Streams migrated to the new interface. For instance, only source-defined cursors are supported. + """ + + @classmethod + def create_from_stream(cls, stream: Stream, source: AbstractSource, logger: logging.Logger, max_workers: int) -> Stream: + """ + Create a ConcurrentStream from a Stream object. + :param source: The source + :param stream: The stream + :param max_workers: The maximum number of worker thread to use + :return: + """ + pk = cls._get_primary_key_from_stream(stream.primary_key) + cursor_field = cls._get_cursor_field_from_stream(stream) + + if not source.message_repository: + raise ValueError( + "A message repository is required to emit non-record messages. Please set the message repository on the source." 
+ ) + + message_repository = source.message_repository + return StreamFacade( + ThreadBasedConcurrentStream( + partition_generator=StreamPartitionGenerator(stream, message_repository), + max_workers=max_workers, + name=stream.name, + json_schema=stream.get_json_schema(), + availability_strategy=StreamAvailabilityStrategy(stream, source), + primary_key=pk, + cursor_field=cursor_field, + slice_logger=source._slice_logger, + message_repository=message_repository, + logger=logger, + ) + ) + + @classmethod + def _get_primary_key_from_stream(cls, stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]: + if stream_primary_key is None: + return [] + elif isinstance(stream_primary_key, str): + return [stream_primary_key] + elif isinstance(stream_primary_key, list): + if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key): + return stream_primary_key # type: ignore # We verified all items in the list are strings + else: + raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") + else: + raise ValueError(f"Invalid type for primary key: {stream_primary_key}") + + @classmethod + def _get_cursor_field_from_stream(cls, stream: Stream) -> Optional[str]: + if isinstance(stream.cursor_field, list): + if len(stream.cursor_field) > 1: + raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}") + elif len(stream.cursor_field) == 0: + return None + else: + return stream.cursor_field[0] + else: + return stream.cursor_field + + def __init__(self, stream: AbstractStream): + """ + :param stream: The underlying AbstractStream + """ + self._abstract_stream = stream + + def read_full_refresh( + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + ) -> Iterable[StreamData]: + """ + Read full refresh. Delegate to the underlying AbstractStream, ignoring all the parameters + :param cursor_field: (ignored) + :param logger: (ignored) + :param slice_logger: (ignored) + :return: Iterable of StreamData + """ + for record in self._abstract_stream.read(): + yield record.data + + def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> Iterable[StreamData]: + if sync_mode == SyncMode.full_refresh: + for record in self._abstract_stream.read(): + yield record.data + else: + # Incremental reads are not supported + raise NotImplementedError + + @property + def name(self) -> str: + return self._abstract_stream.name + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + # This method is not expected to be called directly. 
It is only implemented for backward compatibility with the old interface + return self.as_airbyte_stream().source_defined_primary_key # type: ignore # source_defined_primary_key is known to be an Optional[List[List[str]]] + + @property + def cursor_field(self) -> Union[str, List[str]]: + if self._abstract_stream.cursor_field is None: + return [] + else: + return self._abstract_stream.cursor_field + + @property + def source_defined_cursor(self) -> bool: + # Streams must be aware of their cursor at instantiation time + return True + + @lru_cache(maxsize=None) + def get_json_schema(self) -> Mapping[str, Any]: + return self._abstract_stream.get_json_schema() + + @property + def supports_incremental(self) -> bool: + # Only full refresh is supported + return False + + def check_availability(self, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]: + """ + Verifies the stream is available. Delegates to the underlying AbstractStream and ignores the parameters + :param logger: (ignored) + :param source: (ignored) + :return: + """ + availability = self._abstract_stream.check_availability() + return availability.is_available(), availability.message() + + def get_error_display_message(self, exception: BaseException) -> Optional[str]: + """ + Retrieves the user-friendly display message that corresponds to an exception. + This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. + + A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage. + + :param exception: The exception that was raised + :return: A user-friendly message that indicates the cause of the error + """ + if isinstance(exception, ExceptionWithDisplayMessage): + return exception.display_message + else: + return None + + def as_airbyte_stream(self) -> AirbyteStream: + return self._abstract_stream.as_airbyte_stream() + + def log_stream_sync_configuration(self) -> None: + self._abstract_stream.log_stream_sync_configuration() + + +class StreamPartition(Partition): + """ + This class acts as an adapter between the new Partition interface and the Stream's stream_slice interface + + StreamPartitions are instantiated from a Stream and a stream_slice. + + This class can be used to help enable concurrency on existing connectors without having to rewrite everything as AbstractStream. + In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. + """ + + def __init__(self, stream: Stream, _slice: Optional[Mapping[str, Any]], message_repository: MessageRepository): + """ + :param stream: The stream to delegate to + :param _slice: The partition's stream_slice + :param message_repository: The message repository to use to emit non-record messages + """ + self._stream = stream + self._slice = _slice + self._message_repository = message_repository + + def read(self) -> Iterable[Record]: + """ + Read messages from the stream. + If the StreamData is a Mapping, it will be converted to a Record. + Otherwise, the message will be emitted on the message repository. 
+ """ + try: + for record_data in self._stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=copy.deepcopy(self._slice)): + if isinstance(record_data, Mapping): + yield Record(record_data) + else: + self._message_repository.emit_message(record_data) + except Exception as e: + display_message = self._stream.get_error_display_message(e) + if display_message: + raise ExceptionWithDisplayMessage(display_message) from e + else: + raise e + + def to_slice(self) -> Optional[Mapping[str, Any]]: + return self._slice + + def __hash__(self) -> int: + if self._slice: + # Convert the slice to a string so that it can be hashed + s = json.dumps(self._slice, sort_keys=True) + return hash((self._stream.name, s)) + else: + return hash(self._stream.name) + + def __repr__(self) -> str: + return f"StreamPartition({self._stream.name}, {self._slice})" + + +class StreamPartitionGenerator(PartitionGenerator): + """ + This class acts as an adapter between the new PartitionGenerator and Stream.stream_slices + + This class can be used to help enable concurrency on existing connectors without having to rewrite everything as AbstractStream. + In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. + """ + + def __init__(self, stream: Stream, message_repository: MessageRepository): + """ + :param stream: The stream to delegate to + :param message_repository: The message repository to use to emit non-record messages + """ + self.message_repository = message_repository + self._stream = stream + + def generate(self, sync_mode: SyncMode) -> Iterable[Partition]: + for s in self._stream.stream_slices(sync_mode=sync_mode): + yield StreamPartition(self._stream, copy.deepcopy(s), self.message_repository) + + +@deprecated("This class is experimental. Use at your own risk.") +class AvailabilityStrategyFacade(AvailabilityStrategy): + def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy): + self._abstract_availability_strategy = abstract_availability_strategy + + def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional[Source]) -> Tuple[bool, Optional[str]]: + """ + Checks stream availability. + + Important to note that the stream and source parameters are not used by the underlying AbstractAvailabilityStrategy. + + :param stream: (unused) + :param logger: logger object to use + :param source: (unused) + :return: A tuple of (boolean, str). If boolean is true, then the stream + """ + stream_availability = self._abstract_availability_strategy.check_availability(logger) + return stream_availability.is_available(), stream_availability.message() + + +class StreamAvailabilityStrategy(AbstractAvailabilityStrategy): + """ + This class acts as an adapter between the existing AvailabilityStrategy and the new AbstractAvailabilityStrategy. + StreamAvailabilityStrategy is instantiated with a Stream and a Source to allow the existing AvailabilityStrategy to be used with the new AbstractAvailabilityStrategy interface. + + A more convenient implementation would not depend on the docs URL instead of the Source itself, and would support running on an AbstractStream instead of only on a Stream. + + This class can be used to help enable concurrency on existing connectors without having to rewrite everything as AbstractStream and AbstractAvailabilityStrategy. + In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. 
+ """ + + def __init__(self, stream: Stream, source: Source): + """ + :param stream: The stream to delegate to + :param source: The source to delegate to + """ + self._stream = stream + self._source = source + + def check_availability(self, logger: logging.Logger) -> StreamAvailability: + try: + available, message = self._stream.check_availability(logger, self._source) + if available: + return StreamAvailable() + else: + return StreamUnavailable(str(message)) + except Exception as e: + display_message = self._stream.get_error_display_message(e) + if display_message: + raise ExceptionWithDisplayMessage(display_message) + else: + raise e diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/availability_strategy.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/availability_strategy.py new file mode 100644 index 000000000000..b65803e09df2 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/availability_strategy.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import logging +from abc import ABC, abstractmethod +from typing import Optional + +from deprecated.classic import deprecated + + +class StreamAvailability(ABC): + @abstractmethod + def is_available(self) -> bool: + """ + :return: True if the stream is available. False if the stream is not + """ + + @abstractmethod + def message(self) -> Optional[str]: + """ + :return: A message describing why the stream is not available. If the stream is available, this should return None. + """ + + +class StreamAvailable(StreamAvailability): + def is_available(self) -> bool: + return True + + def message(self) -> Optional[str]: + return None + + +class StreamUnavailable(StreamAvailability): + def __init__(self, message: str): + self._message = message + + def is_available(self) -> bool: + return False + + def message(self) -> Optional[str]: + return self._message + + +# Singleton instances of StreamAvailability to avoid the overhead of creating new dummy objects +STREAM_AVAILABLE = StreamAvailable() + + +@deprecated("This class is experimental. Use at your own risk.") +class AbstractAvailabilityStrategy(ABC): + """ + AbstractAvailabilityStrategy is an experimental interface developed as part of the Concurrent CDK. + This interface is not yet stable and may change in the future. Use at your own risk. + + Why create a new interface instead of using the existing AvailabilityStrategy? + The existing AvailabilityStrategy is tightly coupled with Stream and Source, which yields to circular dependencies and makes it difficult to move away from the Stream interface to AbstractStream. + """ + + @abstractmethod + def check_availability(self, logger: logging.Logger) -> StreamAvailability: + """ + Checks stream availability. + + :param logger: logger object to use + :return: A StreamAvailability object describing the stream's availability + """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/exceptions.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/exceptions.py new file mode 100644 index 000000000000..c67c2c58311d --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/exceptions.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from typing import Any + + +class ExceptionWithDisplayMessage(Exception): + """ + Exception that can be used to display a custom message to the user. 
+ """ + + def __init__(self, display_message: str, **kwargs: Any): + super().__init__(**kwargs) + self.display_message = display_message diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py new file mode 100644 index 000000000000..b4c377e2c12c --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from queue import Queue + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator +from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL, QueueItem + + +class PartitionEnqueuer: + """ + Generates partitions from a partition generator and puts them in a queue. + """ + + def __init__(self, queue: Queue[QueueItem], sentinel: PARTITIONS_GENERATED_SENTINEL) -> None: + """ + :param queue: The queue to put the partitions in. + :param sentinel: The sentinel to put in the queue when all the partitions have been generated. + """ + self._queue = queue + self._sentinel = sentinel + + def generate_partitions(self, partition_generator: PartitionGenerator, sync_mode: SyncMode) -> None: + """ + Generate partitions from a partition generator and put them in a queue. + When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. + + This method is meant to be called in a separate thread. + :param partition_generator: The partition Generator + :param sync_mode: The sync mode used + :return: + """ + for partition in partition_generator.generate(sync_mode=sync_mode): + self._queue.put(partition) + self._queue.put(self._sentinel) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py new file mode 100644 index 000000000000..ce13b48dc56b --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from queue import Queue + +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem + + +class PartitionReader: + """ + Generates records from a partition and puts them in a queuea. + """ + + def __init__(self, queue: Queue[QueueItem]) -> None: + """ + :param queue: The queue to put the records in. + """ + self._queue = queue + + def process_partition(self, partition: Partition) -> None: + """ + Process a partition and put the records in the output queue. + When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. + + This method is meant to be called from a thread. 
+ :param partition: The partition to read data from + :return: None + """ + for record in partition.read(): + self._queue.put(record) + self._queue.put(PartitionCompleteSentinel(partition)) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/__init__.py new file mode 100644 index 000000000000..c941b3045795 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition.py new file mode 100644 index 000000000000..ac9121b4ba1c --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from abc import ABC, abstractmethod +from typing import Any, Iterable, Mapping, Optional + +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + + +class Partition(ABC): + """ + A partition is responsible for reading a specific set of data from a source. + """ + + @abstractmethod + def read(self) -> Iterable[Record]: + """ + Reads the data from the partition. + :return: An iterable of records. + """ + pass + + @abstractmethod + def to_slice(self) -> Optional[Mapping[str, Any]]: + """ + Converts the partition to a slice that can be serialized and deserialized. + :return: A mapping representing a slice + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Returns a hash of the partition. + Partitions must be hashable so that they can be used as keys in a dictionary. + """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py new file mode 100644 index 000000000000..134209467327 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from abc import ABC, abstractmethod +from typing import Iterable + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition + + +class PartitionGenerator(ABC): + @abstractmethod + def generate(self, sync_mode: SyncMode) -> Iterable[Partition]: + """ + Generates partitions for a given sync mode. + :param sync_mode: SyncMode + :return: An iterable of partitions + """ + pass diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/record.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/record.py new file mode 100644 index 000000000000..ddc58b654fd2 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/record.py @@ -0,0 +1,19 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from typing import Any, Mapping + + +class Record: + """ + Represents a record read from a stream. 
+ """ + + def __init__(self, data: Mapping[str, Any]): + self.data = data + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Record): + return False + return self.data == other.data diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py new file mode 100644 index 000000000000..d705555c857f --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -0,0 +1,29 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from typing import Union + +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + +PARTITIONS_GENERATED_SENTINEL = object + + +class PartitionCompleteSentinel: + """ + A sentinel object indicating all records for a partition were produced. + Includes a pointer to the partition that was processed. + """ + + def __init__(self, partition: Partition): + """ + :param partition: The partition that was processed + """ + self.partition = partition + + +""" +Typedef representing the items that can be added to the ThreadBasedConcurrentStream +""" +QueueItem = Union[Record, Partition, PartitionCompleteSentinel, PARTITIONS_GENERATED_SENTINEL, Partition] diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py new file mode 100644 index 000000000000..71213000ab87 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py @@ -0,0 +1,172 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +import concurrent +import time +from concurrent.futures import Future +from functools import lru_cache +from logging import Logger +from queue import Queue +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional + +from airbyte_cdk.models import AirbyteStream, SyncMode +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream +from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, StreamAvailability +from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer +from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL, PartitionCompleteSentinel, QueueItem +from airbyte_cdk.sources.utils.slice_logger import SliceLogger + + +class ThreadBasedConcurrentStream(AbstractStream): + + DEFAULT_TIMEOUT_SECONDS = 300 + DEFAULT_MAX_QUEUE_SIZE = 10_000 + DEFAULT_SLEEP_TIME = 0.1 + + def __init__( + self, + partition_generator: PartitionGenerator, + max_workers: int, + name: str, + json_schema: Mapping[str, Any], + availability_strategy: AbstractAvailabilityStrategy, + primary_key: List[str], + cursor_field: Optional[str], + slice_logger: SliceLogger, + logger: Logger, + message_repository: MessageRepository, + timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, + max_concurrent_tasks: int = DEFAULT_MAX_QUEUE_SIZE, + sleep_time: float = DEFAULT_SLEEP_TIME, + ): + self._stream_partition_generator = partition_generator + self._max_workers = max_workers + self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool") + self._name = name + self._json_schema = json_schema + self._availability_strategy = availability_strategy + self._primary_key = primary_key + self._cursor_field = cursor_field + self._slice_logger = slice_logger + self._logger = logger + self._message_repository = message_repository + self._timeout_seconds = timeout_seconds + self._max_concurrent_tasks = max_concurrent_tasks + self._sleep_time = sleep_time + + def read(self) -> Iterable[Record]: + """ + Read all data from the stream (only full-refresh is supported at the moment) + + Algorithm: + 1. Submit a future to generate the stream's partition to process. + - This has to be done asynchronously because we sometimes need to submit requests to the API to generate all partitions (eg for substreams). + - The future will add the partitions to process on a work queue + 2. Continuously poll work from the work queue until all partitions are generated and processed + - If the next work item is a partition, submit a future to process it. 
+ - The future will add the records to emit on the work queue + - Add the partitions to the partitions_to_done dict so we know it needs to complete for the sync to succeed + - If the next work item is a record, yield the record + - If the next work item is PARTITIONS_GENERATED_SENTINEL, all the partitions were generated + - If the next work item is a PartitionCompleteSentinel, a partition is done processing + - Update the value in partitions_to_done to True so we know the partition is completed + """ + self._logger.debug(f"Processing stream slices for {self.name} (sync_mode: full_refresh)") + futures: List[Future[Any]] = [] + queue: Queue[QueueItem] = Queue() + partition_generator = PartitionEnqueuer(queue, PARTITIONS_GENERATED_SENTINEL) + partition_reader = PartitionReader(queue) + + # Submit partition generation tasks + self._submit_task(futures, partition_generator.generate_partitions, self._stream_partition_generator, SyncMode.full_refresh) + + # True -> partition is done + # False -> partition is not done + partitions_to_done: Dict[Partition, bool] = {} + + finished_partitions = False + while record_or_partition := queue.get(block=True, timeout=self._timeout_seconds): + if record_or_partition == PARTITIONS_GENERATED_SENTINEL: + # All partitions were generated + finished_partitions = True + elif isinstance(record_or_partition, PartitionCompleteSentinel): + # All records for a partition were generated + if record_or_partition.partition not in partitions_to_done: + raise RuntimeError( + f"Received sentinel for partition {record_or_partition.partition} that was not in partitions. This is indicative of a bug in the CDK. Please contact support.partitions:\n{partitions_to_done}" + ) + partitions_to_done[record_or_partition.partition] = True + elif isinstance(record_or_partition, Record): + # Emit records + yield record_or_partition + elif isinstance(record_or_partition, Partition): + # A new partition was generated and must be processed + partitions_to_done[record_or_partition] = False + if self._slice_logger.should_log_slice_message(self._logger): + self._message_repository.emit_message(self._slice_logger.create_slice_log_message(record_or_partition.to_slice())) + self._submit_task(futures, partition_reader.process_partition, record_or_partition) + if finished_partitions and all(partitions_to_done.values()): + # All partitions were generated and process. 
We're done here + break + self._check_for_errors(futures) + + def _submit_task(self, futures: List[Future[Any]], function: Callable[..., Any], *args: Any) -> None: + # Submit a task to the threadpool, waiting if there are too many pending tasks + self._wait_while_too_many_pending_futures(futures) + futures.append(self._threadpool.submit(function, *args)) + + def _wait_while_too_many_pending_futures(self, futures: List[Future[Any]]) -> None: + # Wait until the number of pending tasks is < self._max_concurrent_tasks + while True: + pending_futures = [f for f in futures if not f.done()] + if len(pending_futures) < self._max_concurrent_tasks: + break + self._logger.info("Main thread is sleeping because the task queue is full...") + time.sleep(self._sleep_time) + + def _check_for_errors(self, futures: List[Future[Any]]) -> None: + exceptions_from_futures = [f for f in [future.exception() for future in futures] if f is not None] + if exceptions_from_futures: + raise RuntimeError(f"Failed reading from stream {self.name} with errors: {exceptions_from_futures}") + futures_not_done = [f for f in futures if not f.done()] + if futures_not_done: + raise RuntimeError(f"Failed reading from stream {self.name} with futures not done: {futures_not_done}") + + @property + def name(self) -> str: + return self._name + + def check_availability(self) -> StreamAvailability: + return self._availability_strategy.check_availability(self._logger) + + @property + def cursor_field(self) -> Optional[str]: + return self._cursor_field + + @lru_cache(maxsize=None) + def get_json_schema(self) -> Mapping[str, Any]: + return self._json_schema + + def as_airbyte_stream(self) -> AirbyteStream: + stream = AirbyteStream(name=self.name, json_schema=dict(self._json_schema), supported_sync_modes=[SyncMode.full_refresh]) + + keys = self._primary_key + if keys and len(keys) > 0: + stream.source_defined_primary_key = keys + + return stream + + def log_stream_sync_configuration(self) -> None: + self._logger.debug( + f"Syncing stream instance: {self.name}", + extra={ + "primary_key": self._primary_key, + "cursor_field": self.cursor_field, + }, + ) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py index 03698afa5747..2f2fde6c65d4 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py @@ -15,6 +15,7 @@ # list of all possible HTTP methods which can be used for sending of request bodies from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader +from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer from deprecated.classic import deprecated @@ -105,6 +106,23 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: """ return None + def read_full_refresh( + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + ) -> Iterable[StreamData]: + slices = self.stream_slices(sync_mode=SyncMode.full_refresh, cursor_field=cursor_field) + logger.debug(f"Processing stream slices for {self.name} (sync_mode: full_refresh)", extra={"stream_slices": slices}) + for _slice in slices: + if slice_logger.should_log_slice_message(logger): + yield slice_logger.create_slice_log_message(_slice) + yield from self.read_records( + stream_slice=_slice, + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + ) + @abstractmethod def read_records( 
self, @@ -252,6 +270,18 @@ def get_updated_state( """ return {} + def log_stream_sync_configuration(self) -> None: + """ + Logs the configuration of this stream. + """ + self.logger.debug( + f"Syncing stream instance: {self.name}", + extra={ + "primary_key": self.primary_key, + "cursor_field": self.cursor_field, + }, + ) + @staticmethod def _wrapped_primary_key(keys: Optional[Union[str, List[str], List[List[str]]]]) -> Optional[List[List[str]]]: """ diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/__init__.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/__init__.py new file mode 100644 index 000000000000..c941b3045795 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_adapters.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_adapters.py new file mode 100644 index 000000000000..77f2b31e7bb0 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_adapters.py @@ -0,0 +1,336 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import unittest +from unittest.mock import Mock + +import pytest +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.message import InMemoryMessageRepository +from airbyte_cdk.sources.streams.concurrent.adapters import ( + AvailabilityStrategyFacade, + StreamAvailabilityStrategy, + StreamFacade, + StreamPartition, + StreamPartitionGenerator, +) +from airbyte_cdk.sources.streams.concurrent.availability_strategy import STREAM_AVAILABLE, StreamAvailable, StreamUnavailable +from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + + +@pytest.mark.parametrize( + "stream_availability, expected_available, expected_message", + [ + pytest.param(StreamAvailable(), True, None, id="test_stream_is_available"), + pytest.param(STREAM_AVAILABLE, True, None, id="test_stream_is_available_using_singleton"), + pytest.param(StreamUnavailable("message"), False, "message", id="test_stream_is_available"), + ], +) +def test_availability_strategy_facade(stream_availability, expected_available, expected_message): + strategy = Mock() + strategy.check_availability.return_value = stream_availability + facade = AvailabilityStrategyFacade(strategy) + + logger = Mock() + available, message = facade.check_availability(Mock(), logger, Mock()) + + assert available == expected_available + assert message == expected_message + + strategy.check_availability.assert_called_once_with(logger) + + +def test_stream_availability_strategy(): + stream = Mock() + source = Mock() + stream.check_availability.return_value = True, None + logger = Mock() + availability_strategy = StreamAvailabilityStrategy(stream, source) + + stream_availability = availability_strategy.check_availability(logger) + assert stream_availability.is_available() + assert stream_availability.message() is None + + stream.check_availability.assert_called_once_with(logger, source) + + +@pytest.mark.parametrize( + "sync_mode", + [ + pytest.param(SyncMode.full_refresh, id="test_full_refresh"), + pytest.param(SyncMode.incremental, id="test_incremental"), + ], +) +def test_stream_partition_generator(sync_mode): + stream = Mock() + 
message_repository = Mock() + stream_slices = [{"slice": 1}, {"slice": 2}] + stream.stream_slices.return_value = stream_slices + + partition_generator = StreamPartitionGenerator(stream, message_repository) + + partitions = list(partition_generator.generate(sync_mode)) + slices = [partition.to_slice() for partition in partitions] + assert slices == stream_slices + + +def test_stream_partition(): + stream = Mock() + message_repository = InMemoryMessageRepository() + _slice = None + partition = StreamPartition(stream, _slice, message_repository) + + a_log_message = AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=Level.INFO, + message='slice:{"partition": 1}', + ), + ) + + stream_data = [a_log_message, {"data": 1}, {"data": 2}] + stream.read_records.return_value = stream_data + + records = list(partition.read()) + messages = list(message_repository.consume_queue()) + + expected_records = [ + Record({"data": 1}), + Record({"data": 2}), + ] + + assert records == expected_records + assert messages == [a_log_message] + + +@pytest.mark.parametrize( + "exception_type, expected_display_message", + [ + pytest.param(Exception, None, id="test_exception_no_display_message"), + pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_no_display_message"), + ], +) +def test_stream_partition_raising_exception(exception_type, expected_display_message): + stream = Mock() + stream.get_error_display_message.return_value = expected_display_message + + message_repository = InMemoryMessageRepository() + _slice = None + partition = StreamPartition(stream, _slice, message_repository) + + stream.read_records.side_effect = Exception() + + with pytest.raises(exception_type) as e: + list(partition.read()) + if isinstance(e, ExceptionWithDisplayMessage): + assert e.display_message == "display message" + + +@pytest.mark.parametrize( + "_slice, expected_hash", + [ + pytest.param({"partition": 1, "k": "v"}, hash(("stream", '{"k": "v", "partition": 1}')), id="test_hash_with_slice"), + pytest.param(None, hash("stream"), id="test_hash_no_slice"), + ], +) +def test_stream_partition_hash(_slice, expected_hash): + stream = Mock() + stream.name = "stream" + partition = StreamPartition(stream, _slice, Mock()) + + _hash = partition.__hash__() + assert _hash == expected_hash + + +class StreamFacadeTest(unittest.TestCase): + def setUp(self): + self._abstract_stream = Mock() + self._abstract_stream.name = "stream" + self._abstract_stream.as_airbyte_stream.return_value = AirbyteStream( + name="stream", + json_schema={"type": "object"}, + supported_sync_modes=[SyncMode.full_refresh], + ) + self._facade = StreamFacade(self._abstract_stream) + self._logger = Mock() + self._source = Mock() + self._max_workers = 10 + + self._stream = Mock() + self._stream.primary_key = "id" + + def test_name_is_delegated_to_wrapped_stream(self): + assert self._facade.name == self._abstract_stream.name + + def test_cursor_field_is_a_string(self): + self._abstract_stream.cursor_field = "cursor_field" + assert self._facade.cursor_field == "cursor_field" + + def test_none_cursor_field_is_converted_to_an_empty_list(self): + self._abstract_stream.cursor_field = None + assert self._facade.cursor_field == [] + + def test_source_defined_cursor_is_true(self): + assert self._facade.source_defined_cursor + + def test_json_schema_is_delegated_to_wrapped_stream(self): + json_schema = {"type": "object"} + self._abstract_stream.get_json_schema.return_value = json_schema + assert self._facade.get_json_schema() == json_schema + 
self._abstract_stream.get_json_schema.assert_called_once_with() + + def test_supports_incremental_is_false(self): + assert self._facade.supports_incremental is False + + def test_check_availability_is_delegated_to_wrapped_stream(self): + availability = StreamAvailable() + self._abstract_stream.check_availability.return_value = availability + assert self._facade.check_availability(Mock(), Mock()) == (availability.is_available(), availability.message()) + self._abstract_stream.check_availability.assert_called_once_with() + + def test_full_refresh(self): + expected_stream_data = [{"data": 1}, {"data": 2}] + records = [Record(data) for data in expected_stream_data] + self._abstract_stream.read.return_value = records + + actual_stream_data = list(self._facade.read_records(SyncMode.full_refresh, None, None, None)) + + assert actual_stream_data == expected_stream_data + + def test_read_records_full_refresh(self): + expected_stream_data = [{"data": 1}, {"data": 2}] + records = [Record(data) for data in expected_stream_data] + self._abstract_stream.read.return_value = records + + actual_stream_data = list(self._facade.read_full_refresh(None, None, None)) + + assert actual_stream_data == expected_stream_data + + def test_read_records_incremental(self): + with self.assertRaises(NotImplementedError): + list(self._facade.read_records(SyncMode.incremental, None, None, None)) + + def test_create_from_stream_stream(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = "id" + stream.cursor_field = "cursor" + + facade = StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + + assert facade.name == "stream" + assert facade.cursor_field == "cursor" + assert facade._abstract_stream._primary_key == ["id"] + + def test_create_from_stream_stream_with_none_primary_key(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = None + stream.cursor_field = [] + + facade = StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + facade._abstract_stream._primary_key is None + + def test_create_from_stream_with_composite_primary_key(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = ["id", "name"] + stream.cursor_field = [] + + facade = StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + facade._abstract_stream._primary_key == ["id", "name"] + + def test_create_from_stream_with_empty_list_cursor(self): + stream = Mock() + stream.primary_key = "id" + stream.cursor_field = [] + + facade = StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + + assert facade.cursor_field == [] + + def test_create_from_stream_raises_exception_if_primary_key_is_nested(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = [["field", "id"]] + + with self.assertRaises(ValueError): + StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + + def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = 123 + + with self.assertRaises(ValueError): + StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + + def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = "id" + stream.cursor_field = ["field", "cursor"] + + with self.assertRaises(ValueError): + StreamFacade.create_from_stream(stream, 
self._source, self._logger, self._max_workers) + + def test_create_from_stream_with_cursor_field_as_list(self): + stream = Mock() + stream.name = "stream" + stream.primary_key = "id" + stream.cursor_field = ["cursor"] + + facade = StreamFacade.create_from_stream(stream, self._source, self._logger, self._max_workers) + assert facade.cursor_field == "cursor" + + def test_create_from_stream_none_message_repository(self): + self._stream.name = "stream" + self._stream.primary_key = "id" + self._stream.cursor_field = "cursor" + self._source.message_repository = None + + with self.assertRaises(ValueError): + StreamFacade.create_from_stream(self._stream, self._source, self._logger, self._max_workers) + + def test_get_error_display_message_no_display_message(self): + self._stream.get_error_display_message.return_value = "display_message" + + facade = StreamFacade.create_from_stream(self._stream, self._source, self._logger, self._max_workers) + + expected_display_message = None + e = Exception() + + display_message = facade.get_error_display_message(e) + + assert expected_display_message == display_message + + def test_get_error_display_message_with_display_message(self): + self._stream.get_error_display_message.return_value = "display_message" + + facade = StreamFacade.create_from_stream(self._stream, self._source, self._logger, self._max_workers) + + expected_display_message = "display_message" + e = ExceptionWithDisplayMessage("display_message") + + display_message = facade.get_error_display_message(e) + + assert expected_display_message == display_message + + +@pytest.mark.parametrize( + "exception, expected_display_message", + [ + pytest.param(Exception("message"), None, id="test_no_display_message"), + pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_no_display_message"), + ], +) +def test_get_error_display_message(exception, expected_display_message): + stream = Mock() + facade = StreamFacade(stream) + + display_message = facade.get_error_display_message(exception) + + assert display_message == expected_display_message diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py new file mode 100644 index 000000000000..aeaacf525072 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
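The adapter tests above exercise StreamFacade.create_from_stream against mocks. In a connector, the intended opt-in point would presumably be the source's streams() method. The sketch below is illustrative only: SourceExample and ExampleStream are hypothetical names, and it assumes the source exposes message_repository and _slice_logger as relied on by adapters.py above.

# Hypothetical connector sketch: wrapping an existing Stream for concurrent reads.
import logging
from typing import Any, Iterable, List, Mapping, Optional, Tuple

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade

logger = logging.getLogger("airbyte")

class ExampleStream(Stream):
    primary_key = "id"

    def read_records(self, sync_mode: SyncMode, cursor_field=None, stream_slice=None, stream_state=None) -> Iterable[Mapping[str, Any]]:
        yield {"id": 1}

    def get_json_schema(self) -> Mapping[str, Any]:
        return {"type": "object"}

class SourceExample(AbstractSource):
    # A message repository is required by StreamFacade.create_from_stream.
    message_repository = InMemoryMessageRepository()

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        return True, None

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        # Wrap each synchronous stream so it is read through ThreadBasedConcurrentStream.
        return [StreamFacade.create_from_stream(ExampleStream(), self, logger, max_workers=4)]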
+# + +from queue import Queue +from unittest.mock import Mock + +import pytest +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.streams.concurrent.adapters import StreamPartition +from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer +from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL + + +@pytest.mark.parametrize( + "slices", [pytest.param([], id="test_no_partitions"), pytest.param([{"partition": 1}, {"partition": 2}], id="test_two_partitions")] +) +def test_partition_generator(slices): + queue = Queue() + partition_generator = PartitionEnqueuer(queue, PARTITIONS_GENERATED_SENTINEL) + + stream = Mock() + message_repository = Mock() + partitions = [StreamPartition(stream, s, message_repository) for s in slices] + stream.generate.return_value = iter(partitions) + + sync_mode = SyncMode.full_refresh + + partition_generator.generate_partitions(stream, sync_mode) + + actual_partitions = [] + while partition := queue.get(False): + if partition == PARTITIONS_GENERATED_SENTINEL: + break + actual_partitions.append(partition) + + assert actual_partitions == partitions diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py new file mode 100644 index 000000000000..77845d7fb0ab --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -0,0 +1,32 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from queue import Queue +from unittest.mock import Mock + +from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel + + +def test_partition_reader(): + queue = Queue() + partition_reader = PartitionReader(queue) + + stream_partition = Mock() + records = [ + Record({"id": 1, "name": "Jack"}), + Record({"id": 2, "name": "John"}), + ] + stream_partition.read.return_value = iter(records) + + partition_reader.process_partition(stream_partition) + + actual_records = [] + while record := queue.get(): + if isinstance(record, PartitionCompleteSentinel): + break + actual_records.append(record) + + assert records == actual_records diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_based_concurrent_stream.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_based_concurrent_stream.py new file mode 100644 index 000000000000..8b5dd8d0fe8d --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_based_concurrent_stream.py @@ -0,0 +1,131 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
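The tests below construct ThreadBasedConcurrentStream with a mocked availability strategy. For reference, a concrete AbstractAvailabilityStrategy only has to implement check_availability(logger) and return a StreamAvailability. The sketch below is illustrative: the token check is a hypothetical stand-in for a real API probe.

# Minimal illustrative AbstractAvailabilityStrategy (hypothetical logic, not part of this PR).
import logging

from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
    AbstractAvailabilityStrategy,
    StreamAvailability,
    StreamAvailable,
    StreamUnavailable,
)

class TokenBasedAvailabilityStrategy(AbstractAvailabilityStrategy):
    def __init__(self, api_token: str):
        self._api_token = api_token

    def check_availability(self, logger: logging.Logger) -> StreamAvailability:
        # A real strategy would probe the API; here we only check that a token was provided.
        if self._api_token:
            return StreamAvailable()
        return StreamUnavailable("No API token was provided; the stream cannot be read.")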
+# + +import unittest +from unittest.mock import Mock, call + +from airbyte_cdk.models import AirbyteStream, SyncMode +from airbyte_cdk.sources.streams.concurrent.availability_strategy import STREAM_AVAILABLE +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.concurrent.thread_based_concurrent_stream import ThreadBasedConcurrentStream + + +class ThreadBasedConcurrentStreamTest(unittest.TestCase): + def setUp(self): + self._partition_generator = Mock() + self._max_workers = 1 + self._name = "name" + self._json_schema = {} + self._availability_strategy = Mock() + self._primary_key = [] + self._cursor_field = None + self._slice_logger = Mock() + self._logger = Mock() + self._message_repository = Mock() + self._stream = ThreadBasedConcurrentStream( + self._partition_generator, + self._max_workers, + self._name, + self._json_schema, + self._availability_strategy, + self._primary_key, + self._cursor_field, + self._slice_logger, + self._logger, + self._message_repository, + 1, + 2, + 0, + ) + + def test_get_json_schema(self): + json_schema = self._stream.get_json_schema() + assert json_schema == self._json_schema + + def test_check_availability(self): + self._availability_strategy.check_availability.return_value = STREAM_AVAILABLE + availability = self._stream.check_availability() + assert availability == STREAM_AVAILABLE + self._availability_strategy.check_availability.assert_called_once_with(self._logger) + + def test_check_for_error_raises_no_exception_if_all_futures_succeeded(self): + futures = [Mock() for _ in range(3)] + for f in futures: + f.exception.return_value = None + + self._stream._check_for_errors(futures) + + def test_check_for_error_raises_an_exception_if_any_of_the_futures_raised_an_exception(self): + futures = [Mock() for _ in range(3)] + for f in futures: + f.exception.return_value = None + futures[0].exception.return_value = Exception("error") + + with self.assertRaises(Exception): + self._stream._check_for_errors(futures) + + def test_check_for_error_raises_an_exception_if_any_of_the_futures_are_not_done(self): + futures = [Mock() for _ in range(3)] + for f in futures: + f.exception.return_value = None + futures[0].done.return_value = False + + with self.assertRaises(Exception): + self._stream._check_for_errors(futures) + + def test_read_no_slice_message(self): + partition = Mock(spec=Partition) + expected_records = [Record({"id": 1}), Record({"id": "2"})] + partition.read.return_value = expected_records + partition.to_slice.return_value = {"slice": "slice"} + self._slice_logger.should_log_slice_message.return_value = False + + self._partition_generator.generate.return_value = [partition] + actual_records = list(self._stream.read()) + + assert expected_records == actual_records + + self._message_repository.emit_message.assert_not_called() + + def test_read_log_slice_message(self): + partition = Mock(spec=Partition) + expected_records = [Record({"id": 1}), Record({"id": "2"})] + partition.read.return_value = expected_records + partition.to_slice.return_value = {"slice": "slice"} + self._slice_logger.should_log_slice_message.return_value = True + slice_log_message = Mock() + self._slice_logger.create_slice_log_message.return_value = slice_log_message + + self._partition_generator.generate.return_value = [partition] + list(self._stream.read()) + + self._message_repository.emit_message.assert_called_once_with(slice_log_message) + + def 
test_as_airbyte_stream(self): + expected_airbyte_stream = AirbyteStream( + name=self._name, + json_schema=self._json_schema, + supported_sync_modes=[SyncMode.full_refresh], + source_defined_cursor=None, + default_cursor_field=None, + source_defined_primary_key=None, + namespace=None, + ) + actual_airbyte_stream = self._stream.as_airbyte_stream() + + assert expected_airbyte_stream == actual_airbyte_stream + + def test_wait_while_task_queue_is_full(self): + f1 = Mock() + f2 = Mock() + + # Verify that the done() method will be called until only one future is still running + f1.done.side_effect = [False, False] + f2.done.side_effect = [False, True] + futures = [f1, f2] + self._stream._wait_while_too_many_pending_futures(futures) + + f1.done.assert_has_calls([call(), call()]) + f2.done.assert_has_calls([call(), call()]) diff --git a/airbyte-cdk/python/unit_tests/sources/streams/test_stream_read.py b/airbyte-cdk/python/unit_tests/sources/streams/test_stream_read.py new file mode 100644 index 000000000000..84c7982c3975 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/streams/test_stream_read.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import logging +from typing import Any, Iterable, List, Mapping, Optional, Union +from unittest.mock import Mock + +import pytest +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.message import InMemoryMessageRepository +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig +from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger + +_A_CURSOR_FIELD = ["NESTED", "CURSOR"] +_DEFAULT_INTERNAL_CONFIG = InternalConfig() +_STREAM_NAME = "STREAM" + + +class _MockStream(Stream): + def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]]): + self._slice_to_records = slice_to_records + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + return None + + def stream_slices( + self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + ) -> Iterable[Optional[Mapping[str, Any]]]: + for partition in self._slice_to_records.keys(): + yield {"partition": partition} + + def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> Iterable[StreamData]: + yield from self._slice_to_records[stream_slice["partition"]] + + def get_json_schema(self) -> Mapping[str, Any]: + return {} + + +def _stream(slice_to_partition_mapping, slice_logger, logger, message_repository): + return _MockStream(slice_to_partition_mapping) + + +def _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository): + stream = _stream(slice_to_partition_mapping, slice_logger, logger, message_repository) + source = Mock() + source._slice_logger = slice_logger + source.message_repository = message_repository + stream = StreamFacade.create_from_stream(stream, source, logger, 1) + stream.logger.setLevel(logger.level) + return stream + + +@pytest.mark.parametrize( + "constructor", + [ + pytest.param(_stream, id="synchronous_reader"), + pytest.param(_concurrent_stream, 
id="concurrent_reader"), + ], +) +def test_full_refresh_read_a_single_slice_with_debug(constructor): + # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object. + # It is done by running the same test cases on both streams + records = [ + {"id": 1, "partition": 1}, + {"id": 2, "partition": 1}, + ] + slice_to_partition = {1: records} + slice_logger = DebugSliceLogger() + logger = _mock_logger(True) + message_repository = InMemoryMessageRepository(Level.DEBUG) + stream = constructor(slice_to_partition, slice_logger, logger, message_repository) + + expected_records = [ + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=Level.INFO, + message='slice:{"partition": 1}', + ), + ), + *records, + ] + + actual_records = _read(stream, logger, slice_logger, message_repository) + + assert expected_records == actual_records + + +@pytest.mark.parametrize( + "constructor", + [ + pytest.param(_stream, id="synchronous_reader"), + pytest.param(_concurrent_stream, id="concurrent_reader"), + ], +) +def test_full_refresh_read_a_single_slice(constructor): + # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object. + # It is done by running the same test cases on both streams + logger = _mock_logger() + slice_logger = DebugSliceLogger() + message_repository = InMemoryMessageRepository(Level.INFO) + + records = [ + {"id": 1, "partition": 1}, + {"id": 2, "partition": 1}, + ] + slice_to_partition = {1: records} + stream = constructor(slice_to_partition, slice_logger, logger, message_repository) + + expected_records = [*records] + + actual_records = _read(stream, logger, slice_logger, message_repository) + + assert expected_records == actual_records + + +@pytest.mark.parametrize( + "constructor", + [ + pytest.param(_stream, id="synchronous_reader"), + pytest.param(_concurrent_stream, id="concurrent_reader"), + ], +) +def test_full_refresh_read_a_two_slices(constructor): + # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object + # It is done by running the same test cases on both streams + logger = _mock_logger() + slice_logger = DebugSliceLogger() + message_repository = InMemoryMessageRepository(Level.INFO) + + records_partition_1 = [ + {"id": 1, "partition": 1}, + {"id": 2, "partition": 1}, + ] + records_partition_2 = [ + {"id": 3, "partition": 2}, + {"id": 4, "partition": 2}, + ] + slice_to_partition = {1: records_partition_1, 2: records_partition_2} + stream = constructor(slice_to_partition, slice_logger, logger, message_repository) + + expected_records = [ + *records_partition_1, + *records_partition_2, + ] + + actual_records = _read(stream, logger, slice_logger, message_repository) + + for record in expected_records: + assert record in actual_records + assert len(expected_records) == len(actual_records) + + +def _read(stream, logger, slice_logger, message_repository): + records = [] + for record in stream.read_full_refresh(_A_CURSOR_FIELD, logger, slice_logger): + for message in message_repository.consume_queue(): + records.append(message) + records.append(record) + return records + + +def _mock_partition_generator(name: str, slices, records_per_partition, *, available=True, debug_log=False): + stream = Mock() + stream.name = name + stream.get_json_schema.return_value = {} + stream.generate_partitions.return_value = iter(slices) + stream.read_records.side_effect = [iter(records) for records in records_per_partition] + 
stream.logger.isEnabledFor.return_value = debug_log + if available: + stream.check_availability.return_value = True, None + else: + stream.check_availability.return_value = False, "A reason why the stream is unavailable" + return stream + + +def _mock_logger(enabled_for_debug=False): + logger = Mock() + logger.isEnabledFor.return_value = enabled_for_debug + logger.level = logging.DEBUG if enabled_for_debug else logging.INFO + return logger
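A simplified, self-contained sketch of the queue-and-sentinel handoff that the PartitionEnqueuer and PartitionReader tests above exercise. The classes and functions below are illustrative stand-ins for the pattern, not CDK APIs:

# Illustrative sketch only: plain-stdlib stand-ins for the queue-and-sentinel
# handoff exercised by the PartitionEnqueuer and PartitionReader tests above.
# None of these names are CDK APIs.
from queue import Queue
from typing import Any, Iterable, List


class AllPartitionsGenerated:
    """Sentinel: the partition generator has enqueued every partition."""


class PartitionComplete:
    """Sentinel: every record of one partition has been enqueued."""


def enqueue_partitions(partitions: Iterable[Any], queue: Queue) -> None:
    # Producer: push each partition, then signal that generation is finished,
    # mirroring what test_partition_generator drains from the queue.
    for partition in partitions:
        queue.put(partition)
    queue.put(AllPartitionsGenerated())


def process_partition(records: Iterable[Any], queue: Queue) -> None:
    # Reader: push each record, then signal that the partition is complete,
    # mirroring what test_partition_reader drains from the queue.
    for record in records:
        queue.put(record)
    queue.put(PartitionComplete())


def drain(queue: Queue) -> List[Any]:
    # Consumer: collect items until a sentinel shows up, as the
    # `while item := queue.get(): ... break` loops in the tests do.
    items = []
    while True:
        item = queue.get()
        if isinstance(item, (AllPartitionsGenerated, PartitionComplete)):
            return items
        items.append(item)


if __name__ == "__main__":
    q = Queue()
    process_partition([{"id": 1}, {"id": 2}], q)
    assert drain(q) == [{"id": 1}, {"id": 2}]

Completion is signaled with sentinels rather than queue emptiness, so a consumer can tell "momentarily empty" apart from "done"; both tests encode that by breaking only when the sentinel is seen.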
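test_wait_while_task_queue_is_full only pins down observable behavior: done() is polled on each future until at most one of them is still pending. A sketch of that back-pressure loop under those assumptions (the real _wait_while_too_many_pending_futures, its threshold, and its sleep interval are not shown in this diff):

# Sketch of the back-pressure behavior the test asserts, not the CDK's code.
# The threshold and poll interval here are assumptions.
import time
from concurrent.futures import Future
from typing import List


def wait_while_too_many_pending_futures(futures: List[Future], max_pending: int = 1, poll_interval: float = 0.1) -> None:
    # Keep polling done() on every future; return once at most `max_pending`
    # of them are still running. With f1.done() -> False, False and
    # f2.done() -> False, True, each future is polled exactly twice, which is
    # what the test checks with assert_has_calls([call(), call()]).
    while sum(1 for f in futures if not f.done()) > max_pending:
        time.sleep(poll_interval)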
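The equivalence tests in test_stream_read.py hinge on the adapter call visible above, StreamFacade.create_from_stream(stream, source, logger, 1), which wraps a synchronous Stream so it can be read through the same read_full_refresh entry point. A minimal usage sketch, assuming the final argument is a worker count and reusing the _MockStream helper defined in that test file; anything not shown in the diff is an assumption, not documented API:

# Minimal usage sketch based on the _concurrent_stream helper above.
# Assumptions: the last argument to create_from_stream is a worker count, and
# _MockStream is the helper class defined in test_stream_read.py above.
import logging
from unittest.mock import Mock

from airbyte_cdk.models import Level
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger

logger = logging.getLogger("airbyte")
slice_logger = DebugSliceLogger()
message_repository = InMemoryMessageRepository(Level.INFO)

# As in the tests, a Mock stands in for the AbstractSource that owns the
# slice logger and message repository.
source = Mock()
source._slice_logger = slice_logger
source.message_repository = message_repository

synchronous_stream = _MockStream({1: [{"id": 1, "partition": 1}]})
concurrent_stream = StreamFacade.create_from_stream(synchronous_stream, source, logger, 1)

# Both the plain Stream and the facade are driven through the same entry point,
# which is what the parametrized tests rely on to compare their output.
records = list(concurrent_stream.read_full_refresh(["partition"], logger, slice_logger))

Keeping read_full_refresh as the shared entry point is what lets the three parametrized tests assert identical records and slice log messages from the synchronous and concurrent readers.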