diff --git a/airbyte-cdk/python/airbyte_cdk/entrypoint.py b/airbyte-cdk/python/airbyte_cdk/entrypoint.py index f89e0ef0ec29..e2fb56a82a59 100644 --- a/airbyte-cdk/python/airbyte_cdk/entrypoint.py +++ b/airbyte-cdk/python/airbyte_cdk/entrypoint.py @@ -207,10 +207,14 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]: return -def launch(source: Source, args: List[str]) -> None: +def get_source_iter(source: Source, args: List[str]) -> Iterable[str]: source_entrypoint = AirbyteEntrypoint(source) parsed_args = source_entrypoint.parse_args(args) - for message in source_entrypoint.run(parsed_args): + return source_entrypoint.run(parsed_args) + + +def launch(source: Source, args: List[str]) -> None: + for message in get_source_iter(source, args): print(message) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py index 0f8bf716cc10..854bdff11a74 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py @@ -4,7 +4,17 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) from airbyte_cdk.models import ( AirbyteCatalog, @@ -28,7 +38,9 @@ from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger from airbyte_cdk.utils.event_timing import create_timer -from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) from airbyte_cdk.utils.traced_exception import AirbyteTracedException _default_message_repository = InMemoryMessageRepository() @@ -41,7 +53,9 @@ class AbstractSource(Source, ABC): """ @abstractmethod - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: """ :param logger: source logger :param config: The user-provided configuration as specified by the source's spec. @@ -62,7 +76,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: """ # Stream name to instance map for applying output object transformation - _stream_to_instance_map: Dict[str, Stream] = {} + _stream_to_instance_map: Mapping[str, Stream] = {} _slice_logger: SliceLogger = DebugSliceLogger() @property @@ -70,14 +84,18 @@ def name(self) -> str: """Source name""" return self.__class__.__name__ - def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: + def discover( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> AirbyteCatalog: """Implements the Discover operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/#discover. 
""" streams = [stream.as_airbyte_stream() for stream in self.streams(config=config)] return AirbyteCatalog(streams=streams) - def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus: + def check( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> AirbyteConnectionStatus: """Implements the Check Connection operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/#check. """ @@ -91,7 +109,9 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None, + state: Optional[ + Union[List[AirbyteStateMessage], MutableMapping[str, Any]] + ] = None, ) -> Iterator[AirbyteMessage]: """Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/.""" logger.info(f"Starting syncing {self.name}") @@ -99,7 +119,9 @@ def read( # TODO assert all streams exist in the connector # get the streams once in case the connector needs to make any queries to generate them stream_instances = {s.name: s for s in self.streams(config)} - state_manager = ConnectorStateManager(stream_instance_map=stream_instances, state=state) + state_manager = ConnectorStateManager( + stream_instance_map=stream_instances, state=state + ) self._stream_to_instance_map = stream_instances stream_name_to_exception: MutableMapping[str, AirbyteTracedException] = {} @@ -117,12 +139,20 @@ def read( try: timer.start_event(f"Syncing stream {configured_stream.stream.name}") - stream_is_available, reason = stream_instance.check_availability(logger, self) + stream_is_available, reason = stream_instance.check_availability( + logger, self + ) if not stream_is_available: - logger.warning(f"Skipped syncing stream '{stream_instance.name}' because it was unavailable. {reason}") + logger.warning( + f"Skipped syncing stream '{stream_instance.name}' because it was unavailable. 
{reason}" + ) continue - logger.info(f"Marking stream {configured_stream.stream.name} as STARTED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.STARTED) + logger.info( + f"Marking stream {configured_stream.stream.name} as STARTED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.STARTED + ) yield from self._read_stream( logger=logger, stream_instance=stream_instance, @@ -130,22 +160,36 @@ def read( state_manager=state_manager, internal_config=internal_config, ) - logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.COMPLETE) + logger.info( + f"Marking stream {configured_stream.stream.name} as STOPPED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.COMPLETE + ) except AirbyteTracedException as e: - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) if self.continue_sync_on_stream_failure: stream_name_to_exception[stream_instance.name] = e else: raise e except Exception as e: yield from self._emit_queued_messages() - logger.exception(f"Encountered an exception while reading stream {configured_stream.stream.name}") - logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) + logger.exception( + f"Encountered an exception while reading stream {configured_stream.stream.name}" + ) + logger.info( + f"Marking stream {configured_stream.stream.name} as STOPPED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) display_message = stream_instance.get_error_display_message(e) if display_message: - raise AirbyteTracedException.from_exception(e, message=display_message) from e + raise AirbyteTracedException.from_exception( + e, message=display_message + ) from e raise e finally: timer.finish_event() @@ -153,7 +197,11 @@ def read( logger.info(timer.report()) if self.continue_sync_on_stream_failure and len(stream_name_to_exception) > 0: - raise AirbyteTracedException(message=self._generate_failed_streams_error_message(stream_name_to_exception)) + raise AirbyteTracedException( + message=self._generate_failed_streams_error_message( + stream_name_to_exception + ) + ) logger.info(f"Finished syncing {self.name}") @property @@ -173,7 +221,9 @@ def _read_stream( internal_config: InternalConfig, ) -> Iterator[AirbyteMessage]: if internal_config.page_size and isinstance(stream_instance, HttpStream): - logger.info(f"Setting page size for {stream_instance.name} to {internal_config.page_size}") + logger.info( + f"Setting page size for {stream_instance.name} to {internal_config.page_size}" + ) stream_instance.page_size = internal_config.page_size logger.debug( f"Syncing configured stream: {configured_stream.stream.name}", @@ -185,7 +235,10 @@ def _read_stream( ) stream_instance.log_stream_sync_configuration() - use_incremental = configured_stream.sync_mode == SyncMode.incremental and stream_instance.supports_incremental + use_incremental = ( + configured_stream.sync_mode == SyncMode.incremental + and stream_instance.supports_incremental + ) if use_incremental: record_iterator = self._read_incremental( logger, @@ -195,7 +248,9 @@ def _read_stream( internal_config, 
) else: - record_iterator = self._read_full_refresh(logger, stream_instance, configured_stream, internal_config) + record_iterator = self._read_full_refresh( + logger, stream_instance, configured_stream, internal_config + ) record_counter = 0 stream_name = configured_stream.stream.name @@ -206,7 +261,9 @@ def _read_stream( if record_counter == 1: logger.info(f"Marking stream {stream_name} as RUNNING") # If we just read the first record of the stream, emit the transition to the RUNNING state - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.RUNNING) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.RUNNING + ) yield from self._emit_queued_messages() yield record @@ -230,7 +287,9 @@ def _read_incremental( :return: """ stream_name = configured_stream.stream.name - stream_state = state_manager.get_stream_state(stream_name, stream_instance.namespace) + stream_state = state_manager.get_stream_state( + stream_name, stream_instance.namespace + ) if stream_state and "state" in dir(stream_instance): stream_instance.state = stream_state # type: ignore # we check that state in the dir(stream_instance) @@ -260,7 +319,9 @@ def _read_full_refresh( internal_config: InternalConfig, ) -> Iterator[AirbyteMessage]: total_records_counter = 0 - for record_data_or_message in stream_instance.read_full_refresh(configured_stream.cursor_field, logger, self._slice_logger): + for record_data_or_message in stream_instance.read_full_refresh( + configured_stream.cursor_field, logger, self._slice_logger + ): message = self._get_message(record_data_or_message, stream_instance) yield message if message.type == MessageType.RECORD: @@ -268,14 +329,21 @@ def _read_full_refresh( if internal_config.is_limit_reached(total_records_counter): return - def _get_message(self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream) -> AirbyteMessage: + def _get_message( + self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream + ) -> AirbyteMessage: """ Converts the input to an AirbyteMessage if it is a StreamData. 
Returns the input as is if it is already an AirbyteMessage """ if isinstance(record_data_or_message, AirbyteMessage): return record_data_or_message else: - return stream_data_to_airbyte_message(stream.name, record_data_or_message, stream.transformer, stream.get_json_schema()) + return stream_data_to_airbyte_message( + stream.name, + record_data_or_message, + stream.transformer, + stream.get_json_schema(), + ) @property def message_repository(self) -> Union[None, MessageRepository]: @@ -293,6 +361,13 @@ def continue_sync_on_stream_failure(self) -> bool: return False @staticmethod - def _generate_failed_streams_error_message(stream_failures: Mapping[str, AirbyteTracedException]) -> str: - failures = ", ".join([f"{stream}: {exception.__repr__()}" for stream, exception in stream_failures.items()]) + def _generate_failed_streams_error_message( + stream_failures: Mapping[str, AirbyteTracedException] + ) -> str: + failures = ", ".join( + [ + f"{stream}: {exception.__repr__()}" + for stream, exception in stream_failures.items() + ] + ) return f"During the sync, the following streams did not sync successfully: {failures}" diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/abstract_source_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/abstract_source_async.py new file mode 100644 index 000000000000..a0d93a02185d --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/abstract_source_async.py @@ -0,0 +1,233 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import logging +from abc import ABC, abstractmethod +from typing import ( + Any, + AsyncGenerator, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) + +from airbyte_cdk.models import ( + AirbyteConnectionStatus, + AirbyteMessage, + AirbyteStateMessage, + AirbyteStreamStatus, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + Status, + SyncMode, +) +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.abstract_source import AbstractSource +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.streams.http.http import HttpStream +from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) + + +class AsyncAbstractSource(AbstractSource, ABC): + """ + Abstract base class for an Airbyte Source. Consumers should implement any abstract methods + in this class to create an Airbyte Specification compliant Source. + """ + + @abstractmethod + async def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: + """ + :param logger: source logger + :param config: The user-provided configuration as specified by the source's spec. + This usually contains information required to check connection e.g. tokens, secrets and keys etc. + :return: A tuple of (boolean, error). If boolean is true, then the connection check is successful + and we can connect to the underlying data source using the provided configuration. 
+ Otherwise, the input config cannot be used to connect to the underlying data source, + and the "error" object should describe what went wrong. + The error object will be cast to string to display the problem to the user. + """ + + async def check( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> AirbyteConnectionStatus: + """Implements the Check Connection operation from the Airbyte Specification. + See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/#check. + """ + check_succeeded, error = await self.check_connection(logger, config) + if not check_succeeded: + return AirbyteConnectionStatus(status=Status.FAILED, message=repr(error)) + return AirbyteConnectionStatus(status=Status.SUCCEEDED) + + @abstractmethod + async def streams(self, config: Mapping[str, Any]) -> List[AsyncStream]: + """ + :param config: The user-provided configuration as specified by the source's spec. + Any stream construction related operation should happen here. + :return: A list of the streams in this source connector. + """ + + async def read( + self, + logger: logging.Logger, + config: Mapping[str, Any], + catalog: ConfiguredAirbyteCatalog, + state: Optional[ + Union[List[AirbyteStateMessage], MutableMapping[str, Any]] + ] = None, + ) -> Iterator[AirbyteMessage]: + """ + Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/. + + This method is not used when the AsyncSource is used in conjunction with the AsyncSourceDispatcher. + """ + ... + + async def read_stream( + self, + logger: logging.Logger, + stream_instance: AsyncStream, + configured_stream: ConfiguredAirbyteStream, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ) -> AsyncGenerator[AirbyteMessage, None]: + if internal_config.page_size and isinstance(stream_instance, HttpStream): + logger.info( + f"Setting page size for {stream_instance.name} to {internal_config.page_size}" + ) + stream_instance.page_size = internal_config.page_size + logger.debug( + f"Syncing configured stream: {configured_stream.stream.name}", + extra={ + "sync_mode": configured_stream.sync_mode, + "primary_key": configured_stream.primary_key, + "cursor_field": configured_stream.cursor_field, + }, + ) + stream_instance.log_stream_sync_configuration() + + use_incremental = ( + configured_stream.sync_mode == SyncMode.incremental + and stream_instance.supports_incremental + ) + if use_incremental: + record_iterator = self._read_incremental( + logger, + stream_instance, + configured_stream, + state_manager, + internal_config, + ) + else: + record_iterator = self._read_full_refresh( + logger, stream_instance, configured_stream, internal_config + ) + + record_counter = 0 + stream_name = configured_stream.stream.name + logger.info(f"Syncing stream: {stream_name} ") + async for record in record_iterator: + if record.type == MessageType.RECORD: + record_counter += 1 + if record_counter == 1: + logger.info(f"Marking stream {stream_name} as RUNNING") + # If we just read the first record of the stream, emit the transition to the RUNNING state + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.RUNNING + ) + for message in self._emit_queued_messages(): + yield message + yield record + + logger.info(f"Read {record_counter} records from {stream_name} stream") + + async def _read_incremental( + self, + logger: logging.Logger, + stream_instance: AsyncStream, + configured_stream: ConfiguredAirbyteStream, + state_manager: 
ConnectorStateManager, + internal_config: InternalConfig, + ) -> AsyncGenerator[AirbyteMessage, None]: + """Read stream using incremental algorithm + + :param logger: + :param stream_instance: + :param configured_stream: + :param state_manager: + :param internal_config: + :return: + """ + stream_name = configured_stream.stream.name + stream_state = state_manager.get_stream_state( + stream_name, stream_instance.namespace + ) + + if stream_state and "state" in dir(stream_instance): + stream_instance.state = stream_state # type: ignore # we check that state in the dir(stream_instance) + logger.info(f"Setting state of {self.name} stream to {stream_state}") + + async for record_data_or_message in stream_instance.read_incremental( + configured_stream.cursor_field, + logger, + self._slice_logger, + stream_state, + state_manager, + self.per_stream_state_enabled, + internal_config, + ): + yield self._get_message(record_data_or_message, stream_instance) + + def _emit_queued_messages(self) -> Iterable[AirbyteMessage]: + if self.message_repository: + yield from self.message_repository.consume_queue() + return + + async def _read_full_refresh( + self, + logger: logging.Logger, + stream_instance: AsyncStream, + configured_stream: ConfiguredAirbyteStream, + internal_config: InternalConfig, + ) -> AsyncGenerator[AirbyteMessage, None]: + total_records_counter = 0 + async for record_data_or_message in stream_instance.read_full_refresh( + configured_stream.cursor_field, logger, self._slice_logger + ): + message = self._get_message(record_data_or_message, stream_instance) + yield message + if message.type == MessageType.RECORD: + total_records_counter += 1 + if internal_config.is_limit_reached(total_records_counter): + return + + def _get_message( + self, + record_data_or_message: Union[StreamData, AirbyteMessage], + stream: AsyncStream, + ) -> AirbyteMessage: + """ + Converts the input to an AirbyteMessage if it is a StreamData. Returns the input as is if it is already an AirbyteMessage + """ + if isinstance(record_data_or_message, AirbyteMessage): + return record_data_or_message + else: + return stream_data_to_airbyte_message( + stream.name, + record_data_or_message, + stream.transformer, + stream.get_json_schema(), + ) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_dispatcher.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_dispatcher.py new file mode 100644 index 000000000000..18dd323a7c94 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_dispatcher.py @@ -0,0 +1,410 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +import asyncio +import logging +from abc import ABC +from queue import Queue +from typing import ( + Any, + Dict, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) + +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteStateMessage, + AirbyteStreamStatus, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + ConnectorSpecification, + SyncMode, +) +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.abstract_source import AbstractSource +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.source_reader import Sentinel, SourceReader +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config +from airbyte_cdk.utils.event_timing import create_timer +from airbyte_cdk.utils.traced_exception import AirbyteTracedException +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) + +DEFAULT_QUEUE_SIZE = 10000 +DEFAULT_SESSION_LIMIT = 10000 +DEFAULT_TIMEOUT = None + + +class SourceDispatcher(AbstractSource, ABC): + """ + Abstract base class for an Airbyte Source that can dispatch to an async source. + """ + + def __init__(self, async_source: AsyncAbstractSource, source: AbstractSource): + self.async_source = async_source + self.source = source + self.queue = Queue(DEFAULT_QUEUE_SIZE) + self.session_limit = DEFAULT_SESSION_LIMIT + + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: + """ + Run the async_source's `check_connection` method on the event loop. + """ + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.async_source.check_connection(logger, config) + ) + + def streams(self, config: Mapping[str, Any]) -> List[AsyncStream]: + """ + Run the async_source's `streams` method on the event loop. + """ + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.async_source.streams(config)) + + def spec(self, logger: logging.Logger) -> ConnectorSpecification: + """ + Run the async_source's `spec` method. 
+ """ + return self.async_source.spec(logger) + + def read( + self, + logger: logging.Logger, + config: Mapping[str, Any], + catalog: ConfiguredAirbyteCatalog, + state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None, + ): + logger.info(f"Starting syncing {self.name}") + config, internal_config = split_config(config) + stream_instances: Mapping[str, AsyncStream] = { + s.name: s for s in self.streams(config) + } + state_manager = ConnectorStateManager( + stream_instance_map=stream_instances, state=state + ) + self._stream_to_instance_map = stream_instances + self._assert_streams(catalog, stream_instances) + n_records = 0 + stream_name_to_exception: MutableMapping[str, AirbyteTracedException] = {} # TODO: wire this option through for asyncio + full_refresh_streams = {} + incremental_streams = {} + + for stream in catalog.streams: + if stream.sync_mode == SyncMode.full_refresh: + full_refresh_streams[stream.stream.name] = stream_instances[stream.stream.name] + else: + incremental_streams[stream.stream.name] = stream_instances[stream.stream.name] + + with create_timer(self.name) as timer: + for record in self._read_async_source( + catalog, full_refresh_streams, timer, logger, state_manager, internal_config + ): + n_records += 1 + yield record + + for record in self._read_sync_source( + catalog, incremental_streams, timer, logger, state_manager, internal_config, stream_name_to_exception + ): + n_records += 1 + yield record + + print(f"_______________________-ASYNCIO SOURCE N RECORDS == {n_records}") + + if self.continue_sync_on_stream_failure and len(stream_name_to_exception) > 0: + raise AirbyteTracedException( + message=self._generate_failed_streams_error_message( + stream_name_to_exception + ) + ) + + logger.info(f"Finished syncing {self.name}") + + def _read_sync_source( + self, + catalog: ConfiguredAirbyteCatalog, + stream_instances: Dict[str, AsyncStream], + timer: Any, + logger: logging.Logger, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + stream_name_to_exception: MutableMapping[str, AirbyteTracedException], + ): + """ + For concurrent streams, records from the sync source. + + TODO: this can be deleted when asyncio is deployed for incremental + """ + for configured_stream in catalog.streams: + stream_instance = stream_instances.get(configured_stream.stream.name) + if not stream_instance: + if not self.raise_exception_on_missing_stream: + continue + raise KeyError( + f"The stream {configured_stream.stream.name} no longer exists in the configuration. " + f"Refresh the schema in replication settings and remove this stream from future sync attempts." + ) + + try: + timer.start_event(f"Syncing stream {configured_stream.stream.name}") + stream_is_available, reason = stream_instance.check_availability( + logger, self + ) + if not stream_is_available: + logger.warning( + f"Skipped syncing stream '{stream_instance.name}' because it was unavailable. 
{reason}" + ) + continue + logger.info( + f"Marking stream {configured_stream.stream.name} as STARTED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.STARTED + ) + yield from self._read_stream( + logger=logger, + stream_instance=stream_instance, + configured_stream=configured_stream, + state_manager=state_manager, + internal_config=internal_config, + ) + logger.info( + f"Marking stream {configured_stream.stream.name} as STOPPED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.COMPLETE + ) + except AirbyteTracedException as e: + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) + if self.continue_sync_on_stream_failure: + stream_name_to_exception[stream_instance.name] = e + else: + raise e + except Exception as e: + yield from self._emit_queued_messages() + logger.exception( + f"Encountered an exception while reading stream {configured_stream.stream.name}" + ) + logger.info( + f"Marking stream {configured_stream.stream.name} as STOPPED" + ) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) + display_message = stream_instance.get_error_display_message(e) + if display_message: + raise AirbyteTracedException.from_exception( + e, message=display_message + ) from e + raise e + finally: + timer.finish_event() + logger.info(f"Finished syncing {configured_stream.stream.name}") + logger.info(timer.report()) + + def _read_async_source( + self, + catalog: ConfiguredAirbyteCatalog, + stream_instances: Dict[str, AsyncStream], + timer: Any, + logger: logging.Logger, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ) -> Iterator[AirbyteMessage]: + """ + Run the async_source's `read_streams` method and yield its results. + """ + streams_in_progress_sentinels = { + s.stream.name: Sentinel(s.stream.name) + for s in catalog.streams + if s.stream.name in stream_instances + } + if not streams_in_progress_sentinels: + return + self.reader = SourceReader( + logger, + self.queue, + streams_in_progress_sentinels, + self._read_streams, + catalog, + stream_instances, + timer, + logger, + state_manager, + internal_config, + ) + for record in self.reader: + yield record + + for record in self.reader.drain(): + if isinstance(record, Exception): + raise record + yield record + + def _assert_streams( + self, + catalog: ConfiguredAirbyteCatalog, + stream_instances: Dict[str, AsyncStream], + ): + for configured_stream in catalog.streams: + stream_instance = stream_instances.get(configured_stream.stream.name) + if not stream_instance: + if not self.async_source.raise_exception_on_missing_stream: + return + raise KeyError( + f"The stream {configured_stream.stream.name} no longer exists in the configuration. " + f"Refresh the schema in replication settings and remove this stream from future sync attempts." 
) + + async def _read_streams( + self, + catalog: ConfiguredAirbyteCatalog, + stream_instances: Dict[str, AsyncStream], + timer: Any, + logger: logging.Logger, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ): + pending_tasks = set() + n_started, n_streams = 0, len(catalog.streams) + streams_iterator = iter(catalog.streams) + exceptions = False + + while (pending_tasks or n_started < n_streams) and not exceptions: + while len(pending_tasks) < self.session_limit and ( + configured_stream := next(streams_iterator, None) + ): + if configured_stream is None: + break + stream_instance = stream_instances.get(configured_stream.stream.name) + self.reader.sessions[ + configured_stream.stream.name + ] = await stream_instance.ensure_session() + pending_tasks.add( + asyncio.create_task( + self._do_async_read_stream( + configured_stream, + stream_instance, + timer, + logger, + state_manager, + internal_config, + ) + ) + ) + n_started += 1 + + done, pending_tasks = await asyncio.wait( + pending_tasks, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + if exc := task.exception(): + for remaining_task in pending_tasks: + remaining_task.cancel() # Task.cancel() is not awaitable + self.queue.put(exc) + exceptions = True + + async def _do_async_read_stream( + self, + configured_stream: ConfiguredAirbyteStream, + stream_instance: AsyncStream, + timer: Any, + logger: logging.Logger, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ): + try: + await self._async_read_stream( + configured_stream, + stream_instance, + timer, + logger, + state_manager, + internal_config, + ) + finally: + self.queue.put(Sentinel(configured_stream.stream.name)) + + async def _async_read_stream( + self, + configured_stream: ConfiguredAirbyteStream, + stream_instance: AsyncStream, + timer: Any, + logger: logging.Logger, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ): + try: + timer.start_event(f"Syncing stream {configured_stream.stream.name}") + stream_is_available, reason = await stream_instance.check_availability( + logger, self.async_source + ) + if not stream_is_available: + logger.warning( + f"Skipped syncing stream '{stream_instance.name}' because it was unavailable. 
{reason}" + ) + return + logger.info(f"Marking stream {configured_stream.stream.name} as STARTED") + self.queue.put( + stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.STARTED + ) + ) + async for record in self.async_source.read_stream( + logger=logger, + stream_instance=stream_instance, + configured_stream=configured_stream, + state_manager=state_manager, + internal_config=internal_config, + ): + self.queue.put(record) + logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") + self.queue.put( + stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.COMPLETE + ) + ) + except AirbyteTracedException as e: + self.queue.put( + stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) + ) + raise e + except Exception as e: + for message in self._emit_queued_messages(): + self.queue.put(message) + logger.exception( + f"Encountered an exception while reading stream {configured_stream.stream.name}" + ) + logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") + self.queue.put( + stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) + ) + display_message = stream_instance.get_error_display_message(e) + if display_message: + raise AirbyteTracedException.from_exception(e, message=display_message) + else: + raise e + finally: + timer.finish_event() + logger.info(f"Finished syncing {configured_stream.stream.name}") + # logger.info(timer.report()) # TODO - this is causing scenario-based test failures diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_reader.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_reader.py new file mode 100644 index 000000000000..02882d9758c2 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/source_reader.py @@ -0,0 +1,66 @@ +import asyncio +import logging +from queue import Queue +from threading import Thread +from typing import Any, Callable, Dict, Iterator + +import aiohttp + +DEFAULT_TIMEOUT = None + + +class Sentinel: + def __init__(self, name: str): + self.name = name + + +class SourceReader(Iterator): + def __init__(self, logger: logging.Logger, queue: Queue, sentinels: Dict[str, Sentinel], reader_fn: Callable, *args: Any): + self.logger = logger + self.queue = queue + self.sentinels = sentinels + self.reader_fn = reader_fn + self.reader_args = args + self.sessions: Dict[str, aiohttp.ClientSession] = {} + + self.thread = Thread(target=self._start_reader_thread) + self.thread.start() + + def _start_reader_thread(self): + asyncio.run(self.reader_fn(*self.reader_args)) + + def __next__(self): + loop = asyncio.get_event_loop() + try: + item = self.queue.get(timeout=DEFAULT_TIMEOUT) + if isinstance(item, Exception): + self.logger.error(f"An error occurred in the async thread: {item}") + self.thread.join() + raise item + if isinstance(item, Sentinel): + # Sessions can only be closed once items in the stream have all been dequeued + if session := self.sessions.pop(item.name, None): + loop.create_task(session.close()) # TODO: this can be done better + try: + self.sentinels.pop(item.name) + except KeyError: + raise RuntimeError(f"The sentinel for stream {item.name} was already dequeued. This is unexpected and indicates a possible problem with the connector. 
Please contact Support.") + if not self.sentinels: + self.thread.join() + raise StopIteration + else: + return self.__next__() + else: + return item + finally: + loop.create_task(self.cleanup()) + + def drain(self): + while not self.queue.empty(): + yield self.queue.get() + self.thread.join() + + async def cleanup(self): + pass + # for session in self.sessions.values(): + # await session.close() diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/async_call_rate.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/async_call_rate.py new file mode 100644 index 000000000000..bb0669120f8c --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/async_call_rate.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import logging +from typing import Any, Optional + +import aiohttp +import aiohttp_client_cache + +from airbyte_cdk.sources.streams.call_rate import AbstractAPIBudget + +MIXIN_BASE = aiohttp.ClientSession + +logger = logging.getLogger("airbyte") + + +class AsyncLimiterMixin(MIXIN_BASE): + """Mixin class that adds rate-limiting behavior to requests.""" + + def __init__( + self, + api_budget: Optional[AbstractAPIBudget], + **kwargs: Any, + ): + self._api_budget = api_budget + super().__init__(**kwargs) # type: ignore # Base Session doesn't take any kwargs + + async def send( + self, request: aiohttp.ClientRequest, **kwargs: Any + ) -> aiohttp.ClientResponse: + """Send a request with rate-limiting.""" + assert ( + self._api_budget is None + ), "API budgets are not supported in the async CDK yet." + return await super().send(request, **kwargs) # type: ignore # MIXIN_BASE should be used with aiohttp.ClientSession + + +class AsyncLimiterSession(AsyncLimiterMixin, aiohttp.ClientSession): + """Session that adds rate-limiting behavior to requests.""" + + +class AsyncCachedLimiterSession( + aiohttp_client_cache.CachedSession, AsyncLimiterMixin, aiohttp.ClientSession +): + """Session class with caching and rate-limiting behavior.""" diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/availability_strategy_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/availability_strategy_async.py new file mode 100644 index 000000000000..23377887b95f --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/availability_strategy_async.py @@ -0,0 +1,35 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import logging +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream + + +class AsyncAvailabilityStrategy(ABC): + """ + Abstract base class for checking stream availability. + """ + + @abstractmethod + async def check_availability( + self, + stream: AsyncStream, + logger: logging.Logger, + source: Optional["AsyncAbstractSource"], + ) -> Tuple[bool, Optional[str]]: + """ + Checks stream availability. + + :param stream: stream + :param logger: source logger + :param source: (optional) source + :return: A tuple of (boolean, str). If boolean is true, then the stream + is available, and no str is required. 
Otherwise, the stream is unavailable + for some reason and the str should describe what went wrong and how to + resolve the unavailability, if possible. + """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/core_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/core_async.py new file mode 100644 index 000000000000..d1ba29a14770 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/core_async.py @@ -0,0 +1,183 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + + +import logging +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, +) + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.streams.core import Stream, StreamData +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig +from airbyte_cdk.sources.utils.slice_logger import SliceLogger + +if TYPE_CHECKING: + from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource + from airbyte_cdk.sources.async_cdk.streams.http.availability_strategy_async import ( + AsyncHttpAvailabilityStrategy, + ) + + +class AsyncStream(Stream, ABC): + """ + Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol. + """ + + async def read_full_refresh( + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + ) -> AsyncGenerator[StreamData, None]: + async for _slice in self.stream_slices( + sync_mode=SyncMode.full_refresh, cursor_field=cursor_field + ): + logger.debug( + f"Processing stream slices for {self.name} (sync_mode: full_refresh)", + extra={"stream_slice": _slice}, + ) + if slice_logger.should_log_slice_message(logger): + yield slice_logger.create_slice_log_message(_slice) + async for record in self.read_records( + stream_slice=_slice, + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + ): + yield record + + async def read_incremental( # type: ignore # ignoring typing for ConnectorStateManager because of circular dependencies + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + stream_state: MutableMapping[str, Any], + state_manager, + per_stream_state_enabled: bool, + internal_config: InternalConfig, + ) -> AsyncGenerator[StreamData, None]: + slices = self.stream_slices( + cursor_field=cursor_field, + sync_mode=SyncMode.incremental, + stream_state=stream_state, + ) + logger.debug( + f"Processing stream slices for {self.name} (sync_mode: incremental)", + extra={"stream_slices": slices}, + ) + + has_slices = False + record_counter = 0 + async for _slice in slices: + has_slices = True + if slice_logger.should_log_slice_message(logger): + yield slice_logger.create_slice_log_message(_slice) + records = self.read_records( + sync_mode=SyncMode.incremental, + stream_slice=_slice, + stream_state=stream_state, + cursor_field=cursor_field or None, + ) + async for record_data_or_message in records: + yield record_data_or_message + if isinstance(record_data_or_message, Mapping) or ( + hasattr(record_data_or_message, "type") + and record_data_or_message.type == MessageType.RECORD + ): + record_data = ( + record_data_or_message + if isinstance(record_data_or_message, Mapping) + else record_data_or_message.record + ) + stream_state = self.get_updated_state(stream_state, record_data) + checkpoint_interval = 
self.state_checkpoint_interval + record_counter += 1 + if ( + checkpoint_interval + and record_counter % checkpoint_interval == 0 + ): + yield self._checkpoint_state( + stream_state, state_manager, per_stream_state_enabled + ) + + if internal_config.is_limit_reached(record_counter): + break + + yield self._checkpoint_state(stream_state, state_manager, per_stream_state_enabled) + + if not has_slices: + # Safety net to ensure we always emit at least one state message even if there are no slices + checkpoint = self._checkpoint_state(stream_state, state_manager, per_stream_state_enabled) + yield checkpoint + + @abstractmethod + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[StreamData, None]: + """ + This method should be overridden by subclasses to read records based on the inputs + """ + ... + + async def stream_slices( + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[Optional[Mapping[str, Any]], None]: + """ + Override to define the slices for this stream. See the stream slicing section of the docs for more information. + + :param sync_mode: + :param cursor_field: + :param stream_state: + :return: + """ + yield None + + async def ensure_session(self, *args: Any, **kwargs: Any) -> Any: + """ + Override to define a session object on the stream. + """ + pass + + @property + def availability_strategy(self) -> Optional["AsyncHttpAvailabilityStrategy"]: + """ + :return: The AvailabilityStrategy used to check whether this stream is available. + """ + return None + + async def check_availability( + self, logger: logging.Logger, source: Optional["AsyncAbstractSource"] = None + ) -> Tuple[bool, Optional[str]]: + """ + Checks whether this stream is available. + + :param logger: source logger + :param source: (optional) source + :return: A tuple of (boolean, str). If boolean is true, then this stream + is available, and no str is required. Otherwise, this stream is unavailable + for some reason and the str should describe what went wrong and how to + resolve the unavailability, if possible. + """ + if self.availability_strategy: + return await self.availability_strategy.check_availability( + self, logger, source + ) + return True, None diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/availability_strategy_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/availability_strategy_async.py new file mode 100644 index 000000000000..5f428b0e3bb4 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/availability_strategy_async.py @@ -0,0 +1,73 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +import logging +from typing import TYPE_CHECKING, Optional, Tuple + +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.streams.utils.stream_helper_async import ( + get_first_record_for_slice, + get_first_stream_slice, +) +from airbyte_cdk.sources.streams.http.availability_strategy import ( + HttpAvailabilityStrategy, +) +from airbyte_cdk.sources.streams.http.utils import HttpError + +if TYPE_CHECKING: + from airbyte_cdk.sources.async_cdk.streams.http.http_async import AsyncHttpStream + + +class AsyncHttpAvailabilityStrategy(HttpAvailabilityStrategy): + async def check_availability( + self, + stream: "AsyncHttpStream", + logger: logging.Logger, + source: Optional["AsyncAbstractSource"], + ) -> Tuple[bool, Optional[str]]: + """ + Check stream availability by attempting to read the first record of the + stream. + + :param stream: stream + :param logger: source logger + :param source: (optional) source + :return: A tuple of (boolean, str). If boolean is true, then the stream + is available, and no str is required. Otherwise, the stream is unavailable + for some reason and the str should describe what went wrong and how to + resolve the unavailability, if possible. + """ + try: + # Some streams need a stream slice to read records (e.g. if they have a SubstreamPartitionRouter) + # Streams that don't need a stream slice will return `None` as their first stream slice. + stream_slice = await get_first_stream_slice(stream) + except StopAsyncIteration: + # If stream_slices has no `next()` item (Note - this is different from stream_slices returning [None]!) + # This can happen when a substream's `stream_slices` method does a `for record in parent_records: yield ` + # without accounting for the case in which the parent stream is empty. + reason = f"Cannot attempt to connect to stream {stream.name} - no stream slices were found, likely because the parent stream is empty." + return False, reason + except HttpError as error: + is_available, reason = self._handle_http_error(stream, logger, source, error) + if not is_available: + reason = f"Unable to get slices for {stream.name} stream, because of error in parent stream. {reason}" + return is_available, reason + + try: + await get_first_record_for_slice(stream, stream_slice) + return True, None + except StopAsyncIteration: + logger.info( + f"Successfully connected to stream {stream.name}, but got 0 records." + ) + return True, None + except HttpError as error: + is_available, reason = self._handle_http_error( + stream, logger, source, error + ) + if not is_available: + reason = f"Unable to read {stream.name} stream. {reason}" + return is_available, reason + + return True, None diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/exceptions_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/exceptions_async.py new file mode 100644 index 000000000000..f0c08bd85f95 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/exceptions_async.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + + +from typing import Union + +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.sources.streams.http.exceptions import AbstractBaseBackoffException + + +class AsyncBaseBackoffException(AbstractBaseBackoffException, HttpError): + def __init__(self, error: HttpError, error_message: str = ""): + error_message = ( + error_message + or f"Request URL: {error.url}, Response Code: {error.status_code}, Response Text: {error.text}" + ) + super().__init__(aiohttp_error=error._aiohttp_error, error_message=error_message) + + +class AsyncUserDefinedBackoffException(AsyncBaseBackoffException): + """ + An exception that exposes how long it attempted to backoff + """ + + def __init__( + self, + backoff: Union[int, float], + error: HttpError, + error_message: str = "", + ): + """ + :param backoff: how long to backoff in seconds + :param request: the request that triggered this backoff exception + :param response: the response that triggered the backoff exception + """ + self.backoff = backoff + super().__init__(error, error_message=error_message) + + +class AsyncDefaultBackoffException(AsyncBaseBackoffException): + pass diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/http_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/http_async.py new file mode 100644 index 000000000000..01343b6d8f7d --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/http/http_async.py @@ -0,0 +1,532 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import json +import logging +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) +from yarl import URL + +import aiohttp +import aiohttp_client_cache +import requests +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.async_cdk.streams.async_call_rate import ( + AsyncCachedLimiterSession, + AsyncLimiterSession, +) +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.async_cdk.streams.http.availability_strategy_async import ( + AsyncHttpAvailabilityStrategy, +) +from airbyte_cdk.sources.async_cdk.streams.http.exceptions_async import ( + AsyncDefaultBackoffException, + AsyncUserDefinedBackoffException, +) +from airbyte_cdk.sources.http_config import MAX_CONNECTION_POOL_SIZE +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.streams.http.auth import NoAuth +from airbyte_cdk.sources.streams.http.auth.core import HttpAuthenticator +from airbyte_cdk.sources.streams.http.exceptions import RequestBodyException +from airbyte_cdk.sources.streams.http.http_base import BaseHttpStream +from airbyte_cdk.sources.streams.http.rate_limiting import default_backoff_handler, async_user_defined_backoff_handler +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH + + +# list of all possible HTTP methods which can be used for sending of request bodies +BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") +RecordsGeneratorFunction = Callable[ + [ + aiohttp.ClientRequest, + aiohttp.ClientResponse, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + AsyncGenerator[StreamData, None], +] + + +class AsyncHttpStream(BaseHttpStream, AsyncStream, ABC): + """ + Base abstract class for an Airbyte Stream using the HTTP protocol with asyncio. 
+ + Basic building block for users building an Airbyte source for an async HTTP API. + """ + + def __init__( + self, authenticator: Optional[Union[HttpAuthenticator, NoAuth]] = NoAuth() + ): + # TODO: wire in rate limiting via ApiBudget + self._api_budget = None + self._session: Optional[aiohttp.ClientSession] = None + # TODO: HttpStream handles other authentication codepaths, which may need to be added later + self._authenticator = authenticator + + @property + def authenticator(self) -> Optional[Union[HttpAuthenticator, NoAuth]]: + # TODO: this behaves differently than http.py, which would return None if self._authenticator is an HttpAuthenticator. + # But, it looks like this property is only used here in http_async.py and Salesforce's streams.py. + # It doesn't appear to be causing any problems with Salesforce. + return self._authenticator + + @property + def availability_strategy(self) -> Optional[AsyncHttpAvailabilityStrategy]: + return AsyncHttpAvailabilityStrategy() + + def request_session(self) -> aiohttp.ClientSession: + """ + Session factory based on use_cache property and call rate limits (api_budget parameter) + :return: instance of request-based session + """ + connector = aiohttp.TCPConnector( + limit_per_host=MAX_CONNECTION_POOL_SIZE, + limit=MAX_CONNECTION_POOL_SIZE, + ) + kwargs = {} + + if self._authenticator: + kwargs["headers"] = self._authenticator.get_auth_header() + + if self.use_cache: + cache_dir = os.getenv(ENV_REQUEST_CACHE_PATH) + # Use in-memory cache if cache_dir is not set + # This is a non-obvious interface, but it ensures we don't write sql files when running unit tests + if cache_dir: + sqlite_path = str(Path(cache_dir) / self.cache_filename) + else: + sqlite_path = "file::memory:?cache=shared" + cache = aiohttp_client_cache.SQLiteBackend( + cache_name=sqlite_path, + allowed_methods=( + "get", + "post", + "put", + "patch", + "options", + "delete", + "list", + ), + ) + return AsyncCachedLimiterSession( + cache=cache, connector=connector, api_budget=self._api_budget + ) + else: + return AsyncLimiterSession( + connector=connector, api_budget=self._api_budget, **kwargs + ) + + async def clear_cache(self) -> None: + """ + Clear cached requests for current session, can be called any time + """ + if isinstance(self._session, aiohttp_client_cache.CachedSession): + await self._session.cache.clear() + + @abstractmethod + async def next_page_token( + self, response: aiohttp.ClientResponse + ) -> Optional[Mapping[str, Any]]: + """ + Override this method to define a pagination strategy. + + The value returned from this method is passed to most other methods in this class. Use it to form a request e.g: set headers or query params. + + :return: The token for the next page from the input response object. Returning None means there are no more pages to read in this response. + """ + + @abstractmethod + async def parse_response( + self, + response: aiohttp.ClientResponse, + *, + stream_state: Mapping[str, Any], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[Mapping[str, Any], None]: + """ + Parses the raw response object into a list of records. + By default, this returns an iterable containing the input. Override to parse differently. + :param response: + :param stream_state: + :param stream_slice: + :param next_page_token: + :return: An iterable containing the parsed response + """ + ... 
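+    # NOTE: illustrative sketch only, not part of this class. A concrete subclass
+    # would typically implement the two abstract methods above along these lines;
+    # the "next_page" and "results" field names are hypothetical and depend on the
+    # API being integrated:
+    #
+    #     async def next_page_token(self, response: aiohttp.ClientResponse) -> Optional[Mapping[str, Any]]:
+    #         body = await response.json()
+    #         # Returning None ends pagination; otherwise the token is passed back into request_params/path.
+    #         return {"page": body["next_page"]} if body.get("next_page") else None
+    #
+    #     async def parse_response(self, response: aiohttp.ClientResponse, **kwargs) -> AsyncGenerator[Mapping[str, Any], None]:
+    #         for record in (await response.json()).get("results", []):
+    #             yield record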
+ + # TODO move all the retry logic to a functor/decorator which is input as an init parameter + def should_retry(self, response: aiohttp.ClientResponse) -> bool: + """ + Override to set different conditions for backoff based on the response from the server. + + By default, back off on the following HTTP response statuses: + - 429 (Too Many Requests) indicating rate limiting + - 500s to handle transient server errors + + Unexpected but transient exceptions (connection timeout, DNS resolution failed, etc..) are retried by default. + """ + return response.status == 429 or 500 <= response.status < 600 + + def backoff_time(self, response: aiohttp.ClientResponse) -> Optional[float]: + """ + Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header. + + This method is called only if should_backoff() returns True for the input request. + + :param response: + :return how long to backoff in seconds. The return value may be a floating point number for subsecond precision. Returning None defers backoff + to the default backoff behavior (e.g using an exponential algorithm). + """ + return None + + def error_message(self, response: aiohttp.ClientResponse) -> str: + """ + Override this method to specify a custom error message which can incorporate the HTTP response received + + :param response: The incoming HTTP response from the partner API + :return: + """ + return "" + + def _create_prepared_request( + self, + path: str, + headers: Optional[Dict[str, str]] = None, + params: Optional[Mapping[str, str]] = None, + json: Optional[Mapping[str, Any]] = None, + data: Optional[Union[str, Mapping[str, Any]]] = None, + ) -> aiohttp.ClientRequest: + return self._create_aiohttp_client_request(path, headers, params, json, data) + + def _create_aiohttp_client_request( + self, + path: str, + headers: Optional[Dict[str, str]] = None, + params: Optional[Mapping[str, str]] = None, + json_data: Optional[Mapping[str, Any]] = None, + data: Optional[Union[str, Mapping[str, Any]]] = None, + ) -> aiohttp.ClientRequest: + str_url = self._join_url(self.url_base, path) + # str_url = "http://localhost:8000" # TODO + url = URL(str_url) + if self.must_deduplicate_query_params(): + query_params = self.deduplicate_query_params(str_url, params) + else: + query_params = params or {} + if self.http_method.upper() in BODY_REQUEST_METHODS: + if json_data and data: + raise RequestBodyException( + "At the same time only one of the 'request_body_data' and 'request_body_json' functions can return data" + ) + elif json_data: + headers = headers or {} + headers.update({"Content-Type": "application/json"}) + data = json.dumps(json_data) + + client_request = aiohttp.ClientRequest( + self.http_method, url, headers=headers, params=query_params, data=data + ) + + return client_request + + async def _send( + self, request: aiohttp.ClientRequest, request_kwargs: Mapping[str, Any] + ) -> aiohttp.ClientResponse: + """ + Wraps sending the request in rate limit and error handlers. + Please note that error handling for HTTP status codes will be ignored if raise_on_http_errors is set to False + + This method handles two types of exceptions: + 1. Expected transient exceptions e.g: 429 status code. + 2. Unexpected transient exceptions e.g: timeout. + + To trigger a backoff, we raise an exception that is handled by the backoff decorator. If an exception is not handled by the decorator will + fail the sync. + + For expected transient exceptions, backoff time is determined by the type of exception raised: + 1. 
CustomBackoffException uses the user-provided backoff value + 2. DefaultBackoffException falls back on the decorator's default behavior e.g: exponential backoff + + Unexpected transient exceptions use the default backoff parameters. + Unexpected persistent exceptions are not handled and will cause the sync to fail. + """ + self.logger.debug( + "Making outbound API request", + extra={ + "headers": request.headers, + "url": request.url, + "request_body": request.body, + }, + ) + if self._session is None: + raise AssertionError( + "The session was not set before attempting to make a request. This is unexpected. Please contact Support." + ) + + response = await self._session.request( + request.method, + request.url, + headers=request.headers, + auth=request.auth, + **request_kwargs, + ) + + # Evaluation of response.text can be heavy, for example, if streaming a large response + # Do it only in debug mode + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "Receiving response", + extra={ + "headers": response.headers, + "status": response.status, + "body": response.text, + }, + ) + try: + return await self.handle_response_with_error(response) + except HttpError as exc: + if self.should_retry(response): + custom_backoff_time = self.backoff_time(response) + error_message = self.error_message(response) + if custom_backoff_time: + raise AsyncUserDefinedBackoffException( + backoff=custom_backoff_time, + error=exc, + error_message=error_message, + ) + else: + raise AsyncDefaultBackoffException( + error=exc, error_message=error_message + ) + elif self.raise_on_http_errors: + # Raise any HTTP exceptions that happened in case there were unexpected ones + raise exc + return response + + async def ensure_session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = self.request_session() + return self._session + + async def _send_request( + self, request: aiohttp.ClientRequest, request_kwargs: Mapping[str, Any] + ) -> aiohttp.ClientResponse: + """ + Creates backoff wrappers which are responsible for retry logic + """ + + """ + Backoff package has max_tries parameter that means total number of + tries before giving up, so if this number is 0 no calls expected to be done. + But for this class we call it max_REtries assuming there would be at + least one attempt and some retry attempts, to comply this logic we add + 1 to expected retries attempts. + """ + max_tries = self.max_retries + """ + According to backoff max_tries docstring: + max_tries: The maximum number of attempts to make before giving + up ...The default value of None means there is no limit to + the number of tries. + This implies that if max_tries is explicitly set to None there is no + limit to retry attempts, otherwise it is limited number of tries. But + this is not true for current version of backoff packages (1.8.0). Setting + max_tries to 0 or negative number would result in endless retry attempts. + Add this condition to avoid an endless loop if it hasn't been set + explicitly (i.e. max_retries is not None). + """ + max_time = self.max_time + """ + According to backoff max_time docstring: + max_time: The maximum total amount of time to try for before + giving up. Once expired, the exception will be allowed to + escape. If a callable is passed, it will be + evaluated at runtime and its return value used. 
+ """ + if max_tries is not None: + max_tries = max(0, max_tries) + 1 + + user_backoff_handler = async_user_defined_backoff_handler( + max_tries=max_tries, max_time=max_time + )(self._send) + backoff_handler = default_backoff_handler( + max_tries=max_tries, max_time=max_time, factor=self.retry_factor + ) + return await backoff_handler(user_backoff_handler)(request, request_kwargs) + + @classmethod + def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: + raise NotImplementedError( + "Async streams should use HttpError.parse_error_message" + ) + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[StreamData, None]: + async def _records_generator_fn( + req, res, state, _slice + ) -> AsyncGenerator[StreamData, None]: + async for record in self.parse_response( + res, stream_slice=_slice, stream_state=state + ): + yield record + + async for record in self._read_pages( + _records_generator_fn, stream_slice, stream_state + ): + yield record + + async def _read_pages( + self, + records_generator_fn: RecordsGeneratorFunction, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[StreamData, None]: + stream_state = stream_state or {} + pagination_complete = False + next_page_token = None + while not pagination_complete: + + async def f() -> Tuple[ + aiohttp.ClientRequest, + aiohttp.ClientResponse, + Optional[Mapping[str, Any]], + ]: + nonlocal next_page_token + request, response = await self._fetch_next_page( + stream_slice, stream_state, next_page_token + ) + next_page_token = await self.next_page_token(response) + return request, response, next_page_token + + request, response, next_page_token = await f() + + async for record in records_generator_fn( + request, response, stream_state, stream_slice + ): + yield record + + if not next_page_token: + pagination_complete = True + + async def _fetch_next_page( + self, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Tuple[aiohttp.ClientRequest, aiohttp.ClientResponse]: + request_headers = self.request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + request = self._create_prepared_request( + path=self.path( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + headers=dict( + request_headers, + **self.authenticator.get_auth_header() if self.authenticator else {}, + ), + params=self.request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + json=self.request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + data=self.request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + ) + request_kwargs = self.request_kwargs( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + + response = await self._send_request(request, request_kwargs) + return request, response + + async def handle_response_with_error( + self, response: aiohttp.ClientResponse + ) -> aiohttp.ClientResponse: + """ + If the response has a non-ok status code, raise an exception, 
otherwise return the response. + + When raising an exception, attach response json data to exception object. + """ + if response.ok: + return response + + exc = HttpError( + aiohttp_error=aiohttp.ClientResponseError( + response.request_info, + response.history, + status=response.status, + message=response.reason or "", + headers=response.headers, + ), + ) + await exc.set_response_data(response) + text = await response.text() + self.logger.error(text) + raise exc + + +class AsyncHttpSubStream(AsyncHttpStream, ABC): + def __init__(self, parent: AsyncHttpStream, **kwargs: Any): + """ + :param parent: should be the instance of HttpStream class + """ + super().__init__(**kwargs) + self.parent = parent + + async def stream_slices( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> AsyncGenerator[Optional[Mapping[str, Any]], None]: + await self.parent.ensure_session() + # iterate over all parent stream_slices + async for stream_slice in self.parent.stream_slices( + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + stream_state=stream_state, + ): + parent_records = self.parent.read_records( + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + stream_slice=stream_slice, + stream_state=stream_state, + ) + + # iterate over all parent records with current stream_slice + async for record in parent_records: + yield {"parent": record} diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/utils/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/utils/stream_helper_async.py b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/utils/stream_helper_async.py new file mode 100644 index 000000000000..d9a2fb1ffffb --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/async_cdk/streams/utils/stream_helper_async.py @@ -0,0 +1,43 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from typing import Any, AsyncGenerator, Mapping, Optional + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.streams.core import StreamData + + +async def get_first_stream_slice(stream: AsyncStream) -> Optional[Mapping[str, Any]]: + """ + Gets the first stream_slice from a given stream's stream_slices. + :param stream: stream + :raises StopAsyncIteration: if there is no first slice to return (the stream_slices generator is empty) + :return: first stream slice from 'stream_slices' generator (`None` is a valid stream slice) + """ + first_slice = await anext( + stream.stream_slices( + cursor_field=stream.cursor_field, + sync_mode=SyncMode.full_refresh, + ) + ) + return first_slice + + +async def get_first_record_for_slice( + stream: AsyncStream, stream_slice: Optional[Mapping[str, Any]] +) -> StreamData: + """ + Gets the first record for a stream_slice of a stream. 
+ :param stream: stream + :param stream_slice: stream_slice + :raises StopAsyncIteration: if there is no first record to return (the read_records generator is empty) + :return: StreamData containing the first record in the slice + """ + record = await anext( + stream.read_records( + sync_mode=SyncMode.full_refresh, stream_slice=stream_slice + ) + ) + return record diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/availability_strategy.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/availability_strategy.py index 3f8755070c4b..fd1accab9674 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/availability_strategy.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/availability_strategy.py @@ -4,20 +4,27 @@ import logging import typing -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union import requests from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy -from airbyte_cdk.sources.streams.utils.stream_helper import get_first_record_for_slice, get_first_stream_slice +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.sources.streams.utils.stream_helper import ( + get_first_record_for_slice, + get_first_stream_slice, +) from requests import HTTPError if typing.TYPE_CHECKING: from airbyte_cdk.sources import Source + from airbyte_cdk.sources.streams.http import HttpStream class HttpAvailabilityStrategy(AvailabilityStrategy): - def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional["Source"]) -> Tuple[bool, Optional[str]]: + def check_availability( + self, stream: Stream, logger: logging.Logger, source: Optional["Source"] + ) -> Tuple[bool, Optional[str]]: """ Check stream availability by attempting to read the first record of the stream. @@ -41,7 +48,7 @@ def check_availability(self, stream: Stream, logger: logging.Logger, source: Opt reason = f"Cannot attempt to connect to stream {stream.name} - no stream slices were found, likely because the parent stream is empty." return False, reason except HTTPError as error: - is_available, reason = self.handle_http_error(stream, logger, source, error) + is_available, reason = self._handle_http_error(stream, logger, source, error) if not is_available: reason = f"Unable to get slices for {stream.name} stream, because of error in parent stream. {reason}" return is_available, reason @@ -50,16 +57,22 @@ def check_availability(self, stream: Stream, logger: logging.Logger, source: Opt get_first_record_for_slice(stream, stream_slice) return True, None except StopIteration: - logger.info(f"Successfully connected to stream {stream.name}, but got 0 records.") + logger.info( + f"Successfully connected to stream {stream.name}, but got 0 records." + ) return True, None except HTTPError as error: - is_available, reason = self.handle_http_error(stream, logger, source, error) + is_available, reason = self._handle_http_error(stream, logger, source, error) if not is_available: reason = f"Unable to read {stream.name} stream. 
{reason}" return is_available, reason - def handle_http_error( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"], error: HTTPError + def _handle_http_error( + self, + stream: Stream, + logger: logging.Logger, + source: Optional["Source"], + error: Union[HTTPError, HttpError], ) -> Tuple[bool, Optional[str]]: """ Override this method to define error handling for various `HTTPError`s @@ -77,22 +90,36 @@ def handle_http_error( for some reason and the str should describe what went wrong and how to resolve the unavailability, if possible. """ - status_code = error.response.status_code - known_status_codes = self.reasons_for_unavailable_status_codes(stream, logger, source, error) + if isinstance(error, HttpError): + status_code = error.status_code + url = error.url + reason = error.reason or error.error_message + else: + # TODO: wrap synchronous codepath's errors in HttpError to delete this path + status_code = error.response.status_code + url = error.response.url + reason = error.response.reason + known_status_codes = self.reasons_for_unavailable_status_codes( + stream, logger, source, error + ) known_reason = known_status_codes.get(status_code) if not known_reason: # If the HTTPError is not in the dictionary of errors we know how to handle, don't except raise error doc_ref = self._visit_docs_message(logger, source) - reason = f"The endpoint {error.response.url} returned {status_code}: {error.response.reason}. {known_reason}. {doc_ref} " - response_error_message = stream.parse_response_error_message(error.response) + reason = f"The endpoint {url} returned {status_code}: {reason}. {known_reason}. {doc_ref} " + response_error_message = stream.parse_error_message(error) if response_error_message: reason += response_error_message return False, reason def reasons_for_unavailable_status_codes( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"], error: HTTPError + self, + stream: Stream, + logger: logging.Logger, + source: Optional["Source"], + error: HTTPError, ) -> Dict[int, str]: """ Returns a dictionary of HTTP status codes that indicate stream @@ -134,9 +161,13 @@ def _visit_docs_message(logger: logging.Logger, source: Optional["Source"]) -> s else: return "Please visit the connector's documentation to learn more. 
" - except FileNotFoundError: # If we are unit testing without implementing spec() method in source + except ( + FileNotFoundError + ): # If we are unit testing without implementing spec() method in source if source: - docs_url = f"https://docs.airbyte.com/integrations/sources/{source.name}" + docs_url = ( + f"https://docs.airbyte.com/integrations/sources/{source.name}" + ) else: docs_url = "https://docs.airbyte.com/integrations/sources/test" diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/exceptions.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/exceptions.py index a97884b53f8a..6a0c7f3f4fbf 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/exceptions.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/exceptions.py @@ -3,12 +3,17 @@ # +from abc import ABC from typing import Union import requests -class BaseBackoffException(requests.exceptions.HTTPError): +class AbstractBaseBackoffException(ABC): + pass + + +class BaseBackoffException(AbstractBaseBackoffException, requests.exceptions.HTTPError): def __init__(self, request: requests.PreparedRequest, response: requests.Response, error_message: str = ""): error_message = ( error_message or f"Request URL: {request.url}, Response Code: {response.status_code}, Response Text: {response.text}" diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py index e5784cd25c03..4183a3815d0e 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py @@ -5,46 +5,73 @@ import logging import os -import urllib from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union -from urllib.parse import urljoin +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import requests import requests_cache from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.http_config import MAX_CONNECTION_POOL_SIZE from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy -from airbyte_cdk.sources.streams.call_rate import APIBudget, CachedLimiterSession, LimiterSession -from airbyte_cdk.sources.streams.core import Stream, StreamData -from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy -from airbyte_cdk.sources.utils.types import JsonType +from airbyte_cdk.sources.streams.call_rate import ( + APIBudget, + CachedLimiterSession, + LimiterSession, +) +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.streams.http.availability_strategy import ( + HttpAvailabilityStrategy, +) +from airbyte_cdk.sources.streams.http.http_base import BaseHttpStream +from airbyte_cdk.sources.streams.http.utils import HttpError from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH from requests.auth import AuthBase from .auth.core import HttpAuthenticator, NoAuth -from .exceptions import DefaultBackoffException, RequestBodyException, UserDefinedBackoffException +from .exceptions import ( + DefaultBackoffException, + RequestBodyException, + UserDefinedBackoffException, +) from .rate_limiting import default_backoff_handler, user_defined_backoff_handler # list of all possible HTTP methods which can be used for sending of request bodies BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") -class HttpStream(Stream, ABC): +class 
HttpStream(BaseHttpStream, ABC): """ Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for a HTTP API. """ - source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table) - page_size: Optional[int] = None # Use this variable to define page size for API http requests with pagination support + page_size: Optional[ + int + ] = None # Use this variable to define page size for API http requests with pagination support # TODO: remove legacy HttpAuthenticator authenticator references - def __init__(self, authenticator: Optional[Union[AuthBase, HttpAuthenticator]] = None, api_budget: Optional[APIBudget] = None): + def __init__( + self, + authenticator: Optional[Union[AuthBase, HttpAuthenticator]] = None, + api_budget: Optional[APIBudget] = None, + ): self._api_budget: APIBudget = api_budget or APIBudget(policies=[]) self._session = self.request_session() self._session.mount( - "https://", requests.adapters.HTTPAdapter(pool_connections=MAX_CONNECTION_POOL_SIZE, pool_maxsize=MAX_CONNECTION_POOL_SIZE) + "https://", + requests.adapters.HTTPAdapter( + pool_connections=MAX_CONNECTION_POOL_SIZE, + pool_maxsize=MAX_CONNECTION_POOL_SIZE, + ), ) self._authenticator: HttpAuthenticator = NoAuth() if isinstance(authenticator, AuthBase): @@ -52,22 +79,6 @@ def __init__(self, authenticator: Optional[Union[AuthBase, HttpAuthenticator]] = elif authenticator: self._authenticator = authenticator - @property - def cache_filename(self) -> str: - """ - Override if needed. Return the name of cache file - Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. - """ - return f"{self.name}.sqlite" - - @property - def use_cache(self) -> bool: - """ - Override if needed. If True, all records will be cached. - Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. - """ - return False - def request_session(self) -> requests.Session: """ Session factory based on use_cache property and call rate limits (api_budget parameter) @@ -92,48 +103,6 @@ def clear_cache(self) -> None: if isinstance(self._session, requests_cache.CachedSession): self._session.cache.clear() # type: ignore # cache.clear is not typed - @property - @abstractmethod - def url_base(self) -> str: - """ - :return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" - """ - - @property - def http_method(self) -> str: - """ - Override if needed. See get_request_data/get_request_json if using POST/PUT/PATCH. - """ - return "GET" - - @property - def raise_on_http_errors(self) -> bool: - """ - Override if needed. If set to False, allows opting-out of raising HTTP code exception. - """ - return True - - @property - def max_retries(self) -> Union[int, None]: - """ - Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit. - """ - return 5 - - @property - def max_time(self) -> Union[int, None]: - """ - Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit. - """ - return 60 * 10 - - @property - def retry_factor(self) -> float: - """ - Override if needed. Specifies factor for backoff policy. 
- """ - return 5 - @property def authenticator(self) -> HttpAuthenticator: return self._authenticator @@ -143,7 +112,9 @@ def availability_strategy(self) -> Optional[AvailabilityStrategy]: return HttpAvailabilityStrategy() @abstractmethod - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token( + self, response: requests.Response + ) -> Optional[Mapping[str, Any]]: """ Override this method to define a pagination strategy. @@ -164,73 +135,6 @@ def path( Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" """ - def request_params( - self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> MutableMapping[str, Any]: - """ - Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs. - - E.g: you might want to define query parameters for paging if next_page_token is not None. - """ - return {} - - def request_headers( - self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Mapping[str, Any]: - """ - Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. - """ - return {} - - def request_body_data( - self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: - """ - Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload. - - If returns a ready text that it will be sent as is. - If returns a dict that it will be converted to a urlencoded form. - E.g. {"key1": "value1", "key2": "value2"} => "key1=value1&key2=value2" - - At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. - """ - return None - - def request_body_json( - self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: - """ - Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload. - - At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. - """ - return None - - def request_kwargs( - self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Mapping[str, Any]: - """ - Override to return a mapping of keyword arguments to be used when creating the HTTP request. - Any option listed in https://docs.python-requests.org/en/latest/api/#requests.adapters.BaseAdapter.send for can be returned from - this method. Note that these options do not conflict with request-level options such as headers, request params, etc.. 
- """ - return {} - @abstractmethod def parse_response( self, @@ -284,24 +188,6 @@ def error_message(self, response: requests.Response) -> str: """ return "" - def must_deduplicate_query_params(self) -> bool: - return False - - def deduplicate_query_params(self, url: str, params: Optional[Mapping[str, Any]]) -> Mapping[str, Any]: - """ - Remove query parameters from params mapping if they are already encoded in the URL. - :param url: URL with - :param params: - :return: - """ - if params is None: - params = {} - query_string = urllib.parse.urlparse(url).query - query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()} - - duplicate_keys_with_same_value = {k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k])} - return {k: v for k, v in params.items() if k not in duplicate_keys_with_same_value} - def _create_prepared_request( self, path: str, @@ -315,7 +201,12 @@ def _create_prepared_request( query_params = self.deduplicate_query_params(url, params) else: query_params = params or {} - args = {"method": self.http_method, "url": url, "headers": headers, "params": query_params} + args = { + "method": self.http_method, + "url": url, + "headers": headers, + "params": query_params, + } if self.http_method.upper() in BODY_REQUEST_METHODS: if json and data: raise RequestBodyException( @@ -325,15 +216,15 @@ def _create_prepared_request( args["json"] = json elif data: args["data"] = data - prepared_request: requests.PreparedRequest = self._session.prepare_request(requests.Request(**args)) + prepared_request: requests.PreparedRequest = self._session.prepare_request( + requests.Request(**args) + ) return prepared_request - @classmethod - def _join_url(cls, url_base: str, path: str) -> str: - return urljoin(url_base, path) - - def _send(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any]) -> requests.Response: + def _send( + self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any] + ) -> requests.Response: """ Wraps sending the request in rate limit and error handlers. Please note that error handling for HTTP status codes will be ignored if raise_on_http_errors is set to False @@ -353,7 +244,12 @@ def _send(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Unexpected persistent exceptions are not handled and will cause the sync to fail. 
""" self.logger.debug( - "Making outbound API request", extra={"headers": request.headers, "url": request.url, "request_body": request.body} + "Making outbound API request", + extra={ + "headers": request.headers, + "url": request.url, + "request_body": request.body, + }, ) response: requests.Response = self._session.send(request, **request_kwargs) @@ -361,17 +257,27 @@ def _send(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, # Do it only in debug mode if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug( - "Receiving response", extra={"headers": response.headers, "status": response.status_code, "body": response.text} + "Receiving response", + extra={ + "headers": response.headers, + "status": response.status_code, + "body": response.text, + }, ) if self.should_retry(response): custom_backoff_time = self.backoff_time(response) error_message = self.error_message(response) if custom_backoff_time: raise UserDefinedBackoffException( - backoff=custom_backoff_time, request=request, response=response, error_message=error_message + backoff=custom_backoff_time, + request=request, + response=response, + error_message=error_message, ) else: - raise DefaultBackoffException(request=request, response=response, error_message=error_message) + raise DefaultBackoffException( + request=request, response=response, error_message=error_message + ) elif self.raise_on_http_errors: # Raise any HTTP exceptions that happened in case there were unexpected ones try: @@ -381,7 +287,9 @@ def _send(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, raise exc return response - def _send_request(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any]) -> requests.Response: + def _send_request( + self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any] + ) -> requests.Response: """ Creates backoff wrappers which are responsible for retry logic """ @@ -417,60 +325,20 @@ def _send_request(self, request: requests.PreparedRequest, request_kwargs: Mappi if max_tries is not None: max_tries = max(0, max_tries) + 1 - user_backoff_handler = user_defined_backoff_handler(max_tries=max_tries, max_time=max_time)(self._send) - backoff_handler = default_backoff_handler(max_tries=max_tries, max_time=max_time, factor=self.retry_factor) + user_backoff_handler = user_defined_backoff_handler( + max_tries=max_tries, max_time=max_time + )(self._send) + backoff_handler = default_backoff_handler( + max_tries=max_tries, max_time=max_time, factor=self.retry_factor + ) return backoff_handler(user_backoff_handler)(request, request_kwargs) @classmethod def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: - """ - Parses the raw response object from a failed request into a user-friendly error message. - By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently. 
+ return HttpError.parse_response_error_message(response) - :param response: - :return: A user-friendly message that indicates the cause of the error - """ - - # default logic to grab error from common fields - def _try_get_error(value: Optional[JsonType]) -> Optional[str]: - if isinstance(value, str): - return value - elif isinstance(value, list): - errors_in_value = [_try_get_error(v) for v in value] - return ", ".join(v for v in errors_in_value if v is not None) - elif isinstance(value, dict): - new_value = ( - value.get("message") - or value.get("messages") - or value.get("error") - or value.get("errors") - or value.get("failures") - or value.get("failure") - or value.get("detail") - ) - return _try_get_error(new_value) - return None - - try: - body = response.json() - return _try_get_error(body) - except requests.exceptions.JSONDecodeError: - return None - - def get_error_display_message(self, exception: BaseException) -> Optional[str]: - """ - Retrieves the user-friendly display message that corresponds to an exception. - This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. - - The default implementation of this method only handles HTTPErrors by passing the response to self.parse_response_error_message(). - The method should be overriden as needed to handle any additional exception types. - - :param exception: The exception that was raised - :return: A user-friendly message that indicates the cause of the error - """ - if isinstance(exception, requests.HTTPError) and exception.response is not None: - return self.parse_response_error_message(exception.response) - return None + def parse_error_message(cls, error: HttpError) -> Optional[str]: + return HttpError.parse_response_error_message(error.response) def read_records( self, @@ -480,13 +348,23 @@ def read_records( stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[StreamData]: yield from self._read_pages( - lambda req, res, state, _slice: self.parse_response(res, stream_slice=_slice, stream_state=state), stream_slice, stream_state + lambda req, res, state, _slice: self.parse_response( + res, stream_slice=_slice, stream_state=state + ), + stream_slice, + stream_state, ) def _read_pages( self, records_generator_fn: Callable[ - [requests.PreparedRequest, requests.Response, Mapping[str, Any], Optional[Mapping[str, Any]]], Iterable[StreamData] + [ + requests.PreparedRequest, + requests.Response, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + Iterable[StreamData], ], stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, @@ -495,8 +373,12 @@ def _read_pages( pagination_complete = False next_page_token = None while not pagination_complete: - request, response = self._fetch_next_page(stream_slice, stream_state, next_page_token) - yield from records_generator_fn(request, response, stream_state, stream_slice) + request, response = self._fetch_next_page( + stream_slice, stream_state, next_page_token + ) + yield from records_generator_fn( + request, response, stream_state, stream_slice + ) next_page_token = self.next_page_token(response) if not next_page_token: @@ -511,15 +393,39 @@ def _fetch_next_page( stream_state: Optional[Mapping[str, Any]] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Tuple[requests.PreparedRequest, requests.Response]: - request_headers = self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + 
request_headers = self.request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) request = self._create_prepared_request( - path=self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + path=self.path( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), headers=dict(request_headers, **self.authenticator.get_auth_header()), - params=self.request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + params=self.request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + json=self.request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + data=self.request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + ) + request_kwargs = self.request_kwargs( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ) - request_kwargs = self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) response = self._send_request(request, request_kwargs) return request, response @@ -534,16 +440,24 @@ def __init__(self, parent: HttpStream, **kwargs: Any): self.parent = parent def stream_slices( - self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: parent_stream_slices = self.parent.stream_slices( - sync_mode=SyncMode.full_refresh, cursor_field=cursor_field, stream_state=stream_state + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + stream_state=stream_state, ) # iterate over all parent stream_slices for stream_slice in parent_stream_slices: parent_records = self.parent.read_records( - sync_mode=SyncMode.full_refresh, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state + sync_mode=SyncMode.full_refresh, + cursor_field=cursor_field, + stream_slice=stream_slice, + stream_state=stream_state, ) # iterate over all parent records with current stream_slice diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http_base.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http_base.py new file mode 100644 index 000000000000..56cd77f3384f --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/http_base.py @@ -0,0 +1,322 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +import urllib +from abc import ABC, abstractmethod +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from urllib.parse import urljoin + +import requests +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.streams.core import Stream, StreamData +from airbyte_cdk.sources.streams.http.utils import HttpError + +from .auth.core import HttpAuthenticator + +# list of all possible HTTP methods which can be used for sending of request bodies +BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") + + +class BaseHttpStream(Stream, ABC): + """ + Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for a HTTP API. + """ + + page_size: Optional[ + int + ] = None # Use this variable to define page size for API http requests with pagination support + + @property + def source_defined_cursor(self) -> bool: + return True + + @property + def cache_filename(self) -> str: + """ + Override if needed. Return the name of cache file + Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. + """ + return f"{self.name}.sqlite" + + @property + def use_cache(self) -> bool: + """ + Override if needed. If True, all records will be cached. + Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. + """ + return False + + @abstractmethod + def request_session(self) -> requests.Session: + """ + Session factory based on use_cache property and call rate limits (api_budget parameter) + :return: instance of request-based session + """ + + @abstractmethod + def clear_cache(self) -> None: + """ + Clear cached requests for current session, can be called any time + """ + + @property + @abstractmethod + def url_base(self) -> str: + """ + :return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" + """ + + @property + def http_method(self) -> str: + """ + Override if needed. See get_request_data/get_request_json if using POST/PUT/PATCH. + """ + return "GET" + + @property + def raise_on_http_errors(self) -> bool: + """ + Override if needed. If set to False, allows opting-out of raising HTTP code exception. + """ + return True + + @property + def max_retries(self) -> Union[int, None]: + """ + Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit. + """ + return 5 + + @property + def max_time(self) -> Union[int, None]: + """ + Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit. + """ + return 60 * 10 + + @property + def retry_factor(self) -> float: + """ + Override if needed. Specifies factor for backoff policy. + """ + return 5 + + @property + @abstractmethod + def authenticator(self) -> HttpAuthenticator: + ... + + @abstractmethod + def next_page_token( + self, response: requests.Response + ) -> Optional[Mapping[str, Any]]: + """ + Override this method to define a pagination strategy. + + The value returned from this method is passed to most other methods in this class. Use it to form a request e.g: set headers or query params. + + :return: The token for the next page from the input response object. Returning None means there are no more pages to read in this response. 
+ """ + + @abstractmethod + def path( + self, + *, + stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> str: + """ + Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" + """ + + def request_params( + self, + stream_state: Optional[Mapping[str, Any]], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> MutableMapping[str, Any]: + """ + Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs. + + E.g: you might want to define query parameters for paging if next_page_token is not None. + """ + return {} + + def request_headers( + self, + stream_state: Optional[Mapping[str, Any]], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + """ + Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. + """ + return {} + + def request_body_data( + self, + stream_state: Optional[Mapping[str, Any]], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Optional[Union[Mapping[str, Any], str]]: + """ + Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload. + + If returns a ready text that it will be sent as is. + If returns a dict that it will be converted to a urlencoded form. + E.g. {"key1": "value1", "key2": "value2"} => "key1=value1&key2=value2" + + At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. + """ + return None + + def request_body_json( + self, + stream_state: Optional[Mapping[str, Any]], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Optional[Mapping[str, Any]]: + """ + Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload. + + At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. + """ + return None + + def request_kwargs( + self, + stream_state: Optional[Mapping[str, Any]], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + """ + Override to return a mapping of keyword arguments to be used when creating the HTTP request. + Any option listed in https://docs.python-requests.org/en/latest/api/#requests.adapters.BaseAdapter.send for can be returned from + this method. Note that these options do not conflict with request-level options such as headers, request params, etc.. + """ + return {} + + @abstractmethod + def parse_response( + self, + response: requests.Response, + *, + stream_state: Mapping[str, Any], + stream_slice: Optional[Mapping[str, Any]] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Iterable[Mapping[str, Any]]: + """ + Parses the raw response object into a list of records. + By default, this returns an iterable containing the input. Override to parse differently. 
+ :param response: + :param stream_state: + :param stream_slice: + :param next_page_token: + :return: An iterable containing the parsed response + """ + + @abstractmethod + def should_retry(self, response: requests.Response) -> bool: + """ + Override to set different conditions for backoff based on the response from the server. + + By default, back off on the following HTTP response statuses: + - 429 (Too Many Requests) indicating rate limiting + - 500s to handle transient server errors + + Unexpected but transient exceptions (connection timeout, DNS resolution failed, etc..) are retried by default. + """ + + def backoff_time(self, response: requests.Response) -> Optional[float]: + """ + Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header. + + This method is called only if should_backoff() returns True for the input request. + + :param response: + :return how long to backoff in seconds. The return value may be a floating point number for subsecond precision. Returning None defers backoff + to the default backoff behavior (e.g using an exponential algorithm). + """ + return None + + def error_message(self, response: requests.Response) -> str: + """ + Override this method to specify a custom error message which can incorporate the HTTP response received + + :param response: The incoming HTTP response from the partner API + :return: + """ + return "" + + def must_deduplicate_query_params(self) -> bool: + return False + + def deduplicate_query_params( + self, url: str, params: Optional[Mapping[str, Any]] + ) -> Mapping[str, Any]: + """ + Remove query parameters from params mapping if they are already encoded in the URL. + :param url: URL with + :param params: + :return: + """ + if params is None: + params = {} + query_string = urllib.parse.urlparse(url).query + query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()} + + duplicate_keys_with_same_value = { + k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k]) + } + return { + k: v for k, v in params.items() if k not in duplicate_keys_with_same_value + } + + @classmethod + def _join_url(cls, url_base: str, path: str) -> str: + return urljoin(url_base, path) + + @classmethod + @abstractmethod + def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: + """ + Parses the raw response object from a failed request into a user-friendly error message. + By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently. + + :param response: + :return: A user-friendly message that indicates the cause of the error + """ + + @classmethod + def parse_error_message(cls, error: HttpError) -> Optional[str]: + return error.parse_error_message() + + def get_error_display_message(self, exception: BaseException) -> Optional[str]: + """ + Retrieves the user-friendly display message that corresponds to an exception. + This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. + + The default implementation of this method only handles HTTPErrors by passing the response to self.parse_response_error_message(). + The method should be overriden as needed to handle any additional exception types. 
+ + :param exception: The exception that was raised + :return: A user-friendly message that indicates the cause of the error + """ + if isinstance(exception, HttpError): + return self.parse_error_message(exception) + elif isinstance(exception, requests.HTTPError) and exception.response is not None: + # TODO: wrap synchronous codepath's errors in HttpError to delete this path + return self.parse_response_error_message(exception.response) + return None + + @abstractmethod + def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> Iterable[StreamData]: + ... diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/rate_limiting.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/rate_limiting.py index 84d320345294..51dcf789783f 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/rate_limiting.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/rate_limiting.py @@ -5,15 +5,23 @@ import logging import sys import time -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Coroutine, Mapping, Optional, Type, Union +import aiohttp import backoff -from requests import PreparedRequest, RequestException, Response, codes, exceptions +from requests import HTTPError, PreparedRequest, RequestException, Response, codes, exceptions -from .exceptions import DefaultBackoffException, UserDefinedBackoffException +from airbyte_cdk.sources.async_cdk.streams.http.exceptions_async import AsyncDefaultBackoffException, AsyncUserDefinedBackoffException +from airbyte_cdk.sources.streams.http.utils import HttpError +from .exceptions import AbstractBaseBackoffException, DefaultBackoffException, UserDefinedBackoffException TRANSIENT_EXCEPTIONS = ( DefaultBackoffException, + AsyncDefaultBackoffException, + aiohttp.ClientPayloadError, + aiohttp.ServerTimeoutError, + aiohttp.ServerConnectionError, + aiohttp.ServerDisconnectedError, exceptions.ConnectTimeout, exceptions.ReadTimeout, exceptions.ConnectionError, @@ -23,30 +31,58 @@ logger = logging.getLogger("airbyte") -SendRequestCallableType = Callable[[PreparedRequest, Mapping[str, Any]], Response] +AioHttpCallableType = Callable[ + [aiohttp.ClientRequest, Mapping[str, Any]], + Coroutine[Any, Any, aiohttp.ClientResponse], +] +RequestsCallableType = Callable[[PreparedRequest, Mapping[str, Any]], Response] +SendRequestCallableType = Union[AioHttpCallableType, RequestsCallableType] def default_backoff_handler( - max_tries: Optional[int], factor: float, max_time: Optional[int] = None, **kwargs: Any + max_tries: Optional[int], + factor: float, + max_time: Optional[int] = None, + **kwargs: Any, ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() - if isinstance(exc, RequestException) and exc.response: - logger.info(f"Status code: {exc.response.status_code}, Response Content: {exc.response.content}") - logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." - ) + if isinstance(exc, HttpError): + logger.info(f"Status code: {exc.status_code}, Response Content: {exc.content}") + logger.info( + f"Caught retryable error '{exc.message}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 
+ ) + + if isinstance(exc, RequestException): + exc = HttpError(requests_error=exc) + + logger.info( + f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + ) def should_give_up(exc: Exception) -> bool: # If a non-rate-limiting related 4XX error makes it this far, it means it was unexpected and probably consistent, so we shouldn't back off - if isinstance(exc, RequestException): + if isinstance(exc, HttpError): + give_up: bool = ( + exc.status_code != codes.too_many_requests and 400 <= exc.status_code < 500 + ) + status_code = exc.status_code + + elif isinstance(exc, RequestException): + # TODO: wrap synchronous codepath's errors in HttpError to delete this path give_up: bool = ( exc.response is not None and exc.response.status_code != codes.too_many_requests and 400 <= exc.response.status_code < 500 ) - if give_up: - logger.info(f"Giving up for returned HTTP status: {exc.response.status_code}") - return give_up - # Only RequestExceptions are retryable, so if we get here, it's not retryable + status_code = exc.response if exc else None + + else: + status_code = None + give_up = True + + if give_up: + logger.info(f"Giving up for returned HTTP status: {status_code}") + + # Only RequestExceptions and HttpExceptions are retryable, so if we get here, it's not retryable return False return backoff.on_exception( @@ -62,14 +98,17 @@ def should_give_up(exc: Exception) -> bool: ) -def user_defined_backoff_handler( - max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any +def _make_user_defined_backoff_handler( + exc_type: Type[AbstractBaseBackoffException], max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() - if isinstance(exc, UserDefinedBackoffException): - if exc.response: + if isinstance(exc, exc_type): + if isinstance(exc, HttpError): + logger.info(f"Status code: {exc.status_code}, Response Content: {exc.content}") + elif exc.response: logger.info(f"Status code: {exc.response.status_code}, Response Content: {exc.response.content}") + retry_after = exc.backoff logger.info(f"Retrying. Sleeping for {retry_after} seconds") time.sleep(retry_after + 1) # extra second to cover any fractions of second @@ -77,13 +116,17 @@ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: def log_give_up(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, RequestException): - logger.error(f"Max retry limit reached. Request: {exc.request}, Response: {exc.response}") + exc = HttpError(requests_error=exc) + if isinstance(exc, (HTTPError, HttpError)): + logger.error( + f"Max retry limit reached. 
Request: {exc.request}, Response: {exc.response}" + ) else: logger.error("Max retry limit reached for unknown request and response") return backoff.on_exception( backoff.constant, - UserDefinedBackoffException, + exc_type, interval=0, # skip waiting, we'll wait in on_backoff handler on_backoff=sleep_on_ratelimit, on_giveup=log_give_up, @@ -92,3 +135,6 @@ def log_give_up(details: Mapping[str, Any]) -> None: max_time=max_time, **kwargs, ) + +user_defined_backoff_handler = lambda *args, **kwargs: _make_user_defined_backoff_handler(UserDefinedBackoffException, *args, **kwargs) +async_user_defined_backoff_handler = lambda *args, **kwargs: _make_user_defined_backoff_handler(AsyncUserDefinedBackoffException, *args, **kwargs) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/utils.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/utils.py new file mode 100644 index 000000000000..b9bbc5d2e2b1 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/utils.py @@ -0,0 +1,166 @@ +import json +from typing import Optional, Union + +import aiohttp +import requests + +from airbyte_cdk.sources.utils.types import JsonType + + +class HttpError(Exception): + def __init__( + self, + requests_error: Optional[requests.RequestException] = None, + aiohttp_error: Optional[aiohttp.ClientResponseError] = None, + error_message: Optional[str] = None, + ): + assert ( + requests_error or aiohttp_error and not (requests_error and aiohttp_error) + ), "requests_error xor aiohttp_error must be supplied" + self._requests_error = requests_error + self._aiohttp_error = aiohttp_error + self._aiohttp_response_json = None + self._aiohttp_response_content = None + self._aiohttp_response_text = None + self._aiohttp_response = None + self.error_message = error_message + + @property + def status_code(self) -> Optional[int]: + if self._requests_error and self._requests_error.response: + return self._requests_error.response.status_code + elif self._aiohttp_error: + return self._aiohttp_error.status + return 0 + + @property + def message(self) -> str: + if self._requests_error: + return str(self._requests_error) + elif self._aiohttp_error: + return self.error_message + else: + return "" + + @property + def content(self) -> Optional[bytes]: + if self._requests_error and self._requests_error.response: + return self._requests_error.response.content + elif self._aiohttp_error: + return self._aiohttp_response_content + return b"" + + @property + def text(self) -> Optional[str]: + if self._requests_error and self._requests_error.response: + return self._requests_error.response.text + elif self._aiohttp_error: + return self._aiohttp_response_text + return "" + + def json(self) -> Optional[JsonType]: + if self._requests_error and self._requests_error.response: + return self._requests_error.response.json() + elif self._aiohttp_error: + return self._aiohttp_response_json + return "" + + @property + def request(self) -> Optional[Union[requests.Request, aiohttp.RequestInfo]]: + if self._requests_error and self._requests_error.response: + return self._requests_error.request + elif self._aiohttp_error: + return self._aiohttp_error.request_info + + @property + def response(self) -> Optional[Union[requests.Response, aiohttp.ClientResponse]]: + if self._requests_error and self._requests_error.response: + return self._requests_error.response + elif self._aiohttp_error: + return self._aiohttp_response + + @property + def url(self) -> str: + if self._requests_error and self._requests_error.request: + return 
self._requests_error.request.url or "" + elif self._aiohttp_error: + return str(self._aiohttp_error.request_info.url) + return "" + + @property + def reason(self) -> Optional[str]: + if self._requests_error and self._requests_error.request: + return self._requests_error.response.reason + elif self._aiohttp_error: + return self._aiohttp_error.message + return "" + + @classmethod + def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: + """ + Parses the raw response object from a failed request into a user-friendly error message. + By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently. + + :param response: + :return: A user-friendly message that indicates the cause of the error + """ + try: + return cls._try_get_error(response.json()) + except requests.exceptions.JSONDecodeError: + return None + + def parse_error_message(self) -> Optional[str]: + """ + Parses the raw response object from a failed request into a user-friendly error message. + By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently. + + :param response: + :return: A user-friendly message that indicates the cause of the error + """ + if self._requests_error and self._requests_error.response: + return self.parse_response_error_message(self._requests_error.response) + elif self._aiohttp_error: + try: + return self._try_get_error(self._aiohttp_response_json) + except requests.exceptions.JSONDecodeError: + return None + return None + + @classmethod + def _try_get_error(cls, value: Optional[JsonType]) -> Optional[str]: + # default logic to grab error from common fields + if isinstance(value, str): + return value + elif isinstance(value, list): + errors_in_value = [cls._try_get_error(v) for v in value] + return ", ".join(v for v in errors_in_value if v is not None) + elif isinstance(value, dict): + new_value = ( + value.get("message") + or value.get("messages") + or value.get("error") + or value.get("errors") + or value.get("failures") + or value.get("failure") + or value.get("detail") + ) + return cls._try_get_error(new_value) + return None + + # Async utils + + async def set_response_data(self, response: aiohttp.ClientResponse): + try: + response_json = await response.json() + except (json.JSONDecodeError, aiohttp.ContentTypeError): + response_json = None + except Exception as exc: + raise NotImplementedError(f"Unexpected!!!!!!!! 
{exc}") # TODO + self.logger.error(f"Unable to get error json from response: {exc}") + response_json = None + + text = await response.text() # This fixed a test + self._aiohttp_response = response + self._aiohttp_response_json = response_json or text + self._aiohttp_response_content = await response.content.read() + self._aiohttp_response_text = text diff --git a/airbyte-cdk/python/setup.py b/airbyte-cdk/python/setup.py index a5cac26e35fd..646633e2713e 100644 --- a/airbyte-cdk/python/setup.py +++ b/airbyte-cdk/python/setup.py @@ -65,6 +65,9 @@ packages=find_packages(exclude=("unit_tests",)), package_data={"airbyte_cdk": ["py.typed", "sources/declarative/declarative_component_schema.yaml"]}, install_requires=[ + "aiohttp~=3.8.6" + "aiohttp-client-cache[aiosqlite]", + "aiosqlite", "airbyte-protocol-models==0.5.1", "backoff", "dpath~=2.0.1", @@ -92,6 +95,7 @@ "freezegun", "mypy", "pytest", + "pytest-asyncio", "pytest-cov", "pytest-mock", "requests-mock", diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/__init__.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/__init__.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_scenarios.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_scenarios.py new file mode 100644 index 000000000000..cb3fa36a7a8f --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_scenarios.py @@ -0,0 +1,341 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +from airbyte_cdk.sources.message import InMemoryMessageRepository +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder +from unit_tests.sources.async_cdk.scenario_based.async_concurrent_stream_source_builder import ( + AlwaysAvailableAvailabilityStrategy, + ConcurrentSourceBuilder, + LocalAsyncStream, +) + +_id_only_stream = LocalAsyncStream( + name="stream1", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=[], + cursor_field=None, + slices=[[{"id": "1"}, {"id": "2"}]], +) + +_id_only_stream_with_primary_key = LocalAsyncStream( + name="stream1", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=["id"], + cursor_field=None, + slices=[[{"id": "1"}, {"id": "2"}]], +) + +_id_only_stream_multiple_partitions = LocalAsyncStream( + name="stream1", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=[], + cursor_field=None, + slices=[[{"id": "1"}, {"id": "2"}], [{"id": "3"}, {"id": "4"}]], +) + +_id_only_stream_multiple_partitions_concurrency_level_two = LocalAsyncStream( # TODO: allow concurrency level to be set + name="stream1", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=[], + cursor_field=None, + slices=[[{"id": "1"}, {"id": "2"}], [{"id": "3"}, {"id": "4"}]], +) + +_stream_raising_exception = LocalAsyncStream( + name="stream1", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=[], + cursor_field=None, + slices=[[{"id": "1"}, ValueError("test exception")]], +) + +test_concurrent_cdk_single_stream = ( + TestScenarioBuilder() + .set_name("test_concurrent_cdk_single_stream") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _id_only_stream, + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + {"data": {"id": "2"}, "stream": "stream1"}, + ] + ) + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + } + ] + } + ) + .build() +) + +test_concurrent_cdk_single_stream_with_primary_key = ( + TestScenarioBuilder() + .set_name("test_concurrent_cdk_single_stream_with_primary_key") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _id_only_stream_with_primary_key, + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + {"data": {"id": "2"}, "stream": "stream1"}, + ] + ) + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + "source_defined_primary_key": [["id"]], + } + ] + } + ) + .build() +) + +test_concurrent_cdk_multiple_streams = ( + 
TestScenarioBuilder() + .set_name("test_concurrent_cdk_multiple_streams") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _id_only_stream, + LocalAsyncStream( + name="stream2", + json_schema={ + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + "key": {"type": ["null", "string"]}, + }, + }, + availability_strategy=AlwaysAvailableAvailabilityStrategy(), + primary_key=[], + cursor_field=None, + slices=[[{"id": "10", "key": "v1"}, {"id": "20", "key": "v2"}]], + ), + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + {"data": {"id": "2"}, "stream": "stream1"}, + {"data": {"id": "10", "key": "v1"}, "stream": "stream2"}, + {"data": {"id": "20", "key": "v2"}, "stream": "stream2"}, + ] + ) + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + }, + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + "key": {"type": ["null", "string"]}, + }, + }, + "name": "stream2", + "supported_sync_modes": ["full_refresh"], + }, + ] + } + ) + .build() +) + +test_concurrent_cdk_partition_raises_exception = ( + TestScenarioBuilder() + .set_name("test_concurrent_partition_raises_exception") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _stream_raising_exception, + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + ] + ) + .set_expected_read_error(ValueError, "test exception") + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + } + ] + } + ) + .build() +) + +test_concurrent_cdk_single_stream_multiple_partitions = ( + TestScenarioBuilder() + .set_name("test_concurrent_cdk_single_stream_multiple_partitions") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _id_only_stream_multiple_partitions, + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + {"data": {"id": "2"}, "stream": "stream1"}, + {"data": {"id": "3"}, "stream": "stream1"}, + {"data": {"id": "4"}, "stream": "stream1"}, + ] + ) + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + "type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + } + ] + } + ) + .build() +) + +test_concurrent_cdk_single_stream_multiple_partitions_concurrency_level_two = ( + TestScenarioBuilder() + .set_name("test_concurrent_cdk_single_stream_multiple_partitions_concurrency_level_2") + .set_config({}) + .set_source_builder( + ConcurrentSourceBuilder() + .set_streams( + [ + _id_only_stream_multiple_partitions_concurrency_level_two, + ] + ) + .set_message_repository(InMemoryMessageRepository()) + ) + .set_expected_records( + [ + {"data": {"id": "1"}, "stream": "stream1"}, + {"data": {"id": "2"}, "stream": "stream1"}, + {"data": {"id": "3"}, "stream": "stream1"}, + {"data": {"id": "4"}, "stream": "stream1"}, + ] + ) + .set_expected_catalog( + { + "streams": [ + { + "json_schema": { + 
"type": "object", + "properties": { + "id": {"type": ["null", "string"]}, + }, + }, + "name": "stream1", + "supported_sync_modes": ["full_refresh"], + } + ] + } + ) + .build() +) diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_source_builder.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_source_builder.py new file mode 100644 index 000000000000..fa21c31ca082 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/async_concurrent_stream_source_builder.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# +import logging +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union + +from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConnectorSpecification, DestinationSyncMode, SyncMode +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.source_dispatcher import SourceDispatcher +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.async_cdk.streams.availability_strategy_async import AsyncAvailabilityStrategy +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_protocol.models import ConfiguredAirbyteStream +from unit_tests.sources.scenario_based.helpers import NeverLogSliceLogger +from unit_tests.sources.scenario_based.scenario_builder import SourceBuilder + + +class AsyncConcurrentCdkSource(AsyncAbstractSource): + def __init__(self, streams: List[AsyncStream]): + self._streams = streams + super().__init__() + + async def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + # Check is not verified because it is up to the source to implement this method + return True, None + + async def streams(self, config: Mapping[str, Any]) -> List[AsyncStream]: + return self._streams + + def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: + return ConnectorSpecification(connectionSpecification={}) + + def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog: + return ConfiguredAirbyteCatalog( + streams=[ + ConfiguredAirbyteStream( + stream=s.as_airbyte_stream(), + sync_mode=SyncMode.full_refresh, + destination_sync_mode=DestinationSyncMode.overwrite, + ) + for s in self._streams + ] + ) + + +class ConcurrentSourceBuilder(SourceBuilder[AsyncConcurrentCdkSource]): + def __init__(self): + self._streams: List[AsyncStream] = [] + self._message_repository = None + + def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> SourceDispatcher: + async_source = AsyncConcurrentCdkSource(self._streams) + async_source._streams = self._streams + async_source._message_repository = self._message_repository + async_source._slice_logger = NeverLogSliceLogger() + return SourceDispatcher(async_source) + + def set_streams(self, streams: List[AsyncStream]) -> "ConcurrentSourceBuilder": + self._streams = streams + return self + + def set_message_repository(self, message_repository: MessageRepository) -> "ConcurrentSourceBuilder": + self._message_repository = message_repository + return self + + +class AlwaysAvailableAvailabilityStrategy(AsyncAvailabilityStrategy): + async def check_availability(self, stream: AsyncStream, logger: logging.Logger, source: Optional["Source"]) -> Tuple[bool, Optional[str]]: + return True, None + + +class LocalAsyncStream(AsyncStream): + def 
__init__( + self, + name: str, + json_schema: Mapping[str, Any], + availability_strategy: Optional[AsyncAvailabilityStrategy], + primary_key: Any, # TODO + cursor_field: Any, # TODO + slices: List[List[Mapping[str, Any]]] + ): + self._name = name + self._json_schema = json_schema + self._availability_strategy = availability_strategy + self._primary_key = primary_key + self._cursor_field = cursor_field + self._slices = slices + + @property + def name(self): + return self._name + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> Iterable[StreamData]: + if stream_slice: + for record in stream_slice: + yield record + else: + raise NotImplementedError + + async def stream_slices( + self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + ) -> Iterable[Optional[Mapping[str, Any]]]: + for stream_slice in self._slices: + yield stream_slice + + @property + def availability_strategy(self) -> Optional[AsyncAvailabilityStrategy]: + return self._availability_strategy + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + return self._primary_key + + def get_json_schema(self) -> Mapping[str, Any]: + return self._json_schema diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/test_async_concurrent_scenarios.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/test_async_concurrent_scenarios.py new file mode 100644 index 000000000000..27ceb279f48b --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/scenario_based/test_async_concurrent_scenarios.py @@ -0,0 +1,56 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
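LocalAsyncStream above is the fixture type behind the scenario definitions earlier in this patch; a minimal sketch of reading such a stream directly, outside the scenario harness, follows. The _read_all driver and the trailing usage comment are illustrative.

import asyncio
from typing import Any, List, Mapping

from airbyte_cdk.models import SyncMode


async def _read_all(stream) -> List[Mapping[str, Any]]:
    # Illustrative driver: walk every slice, then every record within it,
    # which is the shape of output the scenario-based tests assert on.
    records = []
    async for stream_slice in stream.stream_slices(sync_mode=SyncMode.full_refresh):
        async for record in stream.read_records(
            sync_mode=SyncMode.full_refresh, stream_slice=stream_slice
        ):
            records.append(record)
    return records


# e.g. asyncio.run(_read_all(_id_only_stream)) would yield [{"id": "1"}, {"id": "2"}]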
+# + +from pathlib import PosixPath + +import pytest +from freezegun import freeze_time +from pytest import LogCaptureFixture +from unit_tests.sources.scenario_based.helpers import verify_discover, verify_read +from unit_tests.sources.scenario_based.scenario_builder import TestScenario +# from unit_tests.sources.streams.concurrent.scenarios.async_incremental_scenarios import ( +# test_incremental_stream_with_slice_boundaries_no_input_state, +# test_incremental_stream_with_slice_boundaries_with_concurrent_state, +# test_incremental_stream_with_slice_boundaries_with_legacy_state, +# test_incremental_stream_without_slice_boundaries_no_input_state, +# test_incremental_stream_without_slice_boundaries_with_concurrent_state, +# test_incremental_stream_without_slice_boundaries_with_legacy_state, +# ) +from unit_tests.sources.async_cdk.scenario_based.async_concurrent_stream_scenarios import ( + test_concurrent_cdk_multiple_streams, + test_concurrent_cdk_partition_raises_exception, + test_concurrent_cdk_single_stream, + test_concurrent_cdk_single_stream_multiple_partitions, + test_concurrent_cdk_single_stream_multiple_partitions_concurrency_level_two, + test_concurrent_cdk_single_stream_with_primary_key, +) + +scenarios = [ + test_concurrent_cdk_single_stream, + test_concurrent_cdk_multiple_streams, + test_concurrent_cdk_single_stream_multiple_partitions, + test_concurrent_cdk_single_stream_multiple_partitions_concurrency_level_two, + test_concurrent_cdk_single_stream_with_primary_key, + test_concurrent_cdk_partition_raises_exception, + # test_incremental_stream_with_slice_boundaries, + # test_incremental_stream_without_slice_boundaries, + # test_incremental_stream_with_many_slices_but_without_slice_boundaries, + # test_incremental_stream_with_slice_boundaries_no_input_state, + # test_incremental_stream_with_slice_boundaries_with_concurrent_state, + # test_incremental_stream_with_slice_boundaries_with_legacy_state, + # test_incremental_stream_without_slice_boundaries_no_input_state, + # test_incremental_stream_without_slice_boundaries_with_concurrent_state, + # test_incremental_stream_without_slice_boundaries_with_legacy_state, +] + + +@pytest.mark.parametrize("scenario", scenarios, ids=[s.name for s in scenarios]) +@freeze_time("2023-06-09T00:00:00Z") +def test_concurrent_read(scenario: TestScenario) -> None: + verify_read(scenario) + + +@pytest.mark.parametrize("scenario", scenarios, ids=[s.name for s in scenarios]) +def test_concurrent_discover(tmp_path: PosixPath, scenario: TestScenario) -> None: + verify_discover(tmp_path, scenario) diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/__init__.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/__init__.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_availability_strategy_async.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_availability_strategy_async.py new file mode 100644 index 000000000000..830e62f2284a --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_availability_strategy_async.py @@ -0,0 +1,201 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
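The availability and HTTP stream tests that follow stub aiohttp traffic with the aioresponses library; a minimal self-contained sketch of that pattern is below. The URL and status code are illustrative.

import asyncio

import aiohttp
from aioresponses import aioresponses


async def fetch_status(url: str) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return response.status


def test_mocked_get_returns_configured_status():
    with aioresponses() as m:
        # Any GET to this URL is answered by the mock instead of the network.
        m.get("https://example.com/health", status=429)
        assert asyncio.run(fetch_status("https://example.com/health")) == 429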
+# +import asyncio +import logging +from typing import Any, Iterable, List, Mapping, Optional, Tuple + +import aiohttp +import pytest +from aioresponses import aioresponses +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.source_dispatcher import SourceDispatcher +from airbyte_cdk.sources.async_cdk.streams.http.availability_strategy_async import ( + AsyncHttpAvailabilityStrategy, +) +from airbyte_cdk.sources.async_cdk.streams.http.http_async import AsyncHttpStream +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.http.utils import HttpError + +logger = logging.getLogger("airbyte") + + +class MockHttpStream(AsyncHttpStream): + url_base = "https://test_base_url.com" + primary_key = "" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.resp_counter = 1 + + async def next_page_token( + self, response: aiohttp.ClientResponse + ) -> Optional[Mapping[str, Any]]: + return None + + def path(self, **kwargs) -> str: + return "" + + async def parse_response( + self, response: aiohttp.ClientResponse, **kwargs + ) -> Iterable[Mapping]: + stub_resp = {"data": self.resp_counter} + self.resp_counter += 1 + yield stub_resp + + def retry_factor(self) -> float: + return 0.01 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("status_code", "expected_is_available", "expected_messages"), + [ + ( + 403, + False, + [ + "This is most likely due to insufficient permissions on the credentials in use.", + ], + ), + (200, True, []), + ], +) +@pytest.mark.parametrize( + ("include_source", "expected_docs_url_messages"), + [ + ( + True, + [ + "Please visit https://docs.airbyte.com/integrations/sources/MockSource to learn more." + ], + ), + (False, ["Please visit the connector's documentation to learn more."]), + ], +) +async def test_default_http_availability_strategy( + status_code, + expected_is_available, + expected_messages, + include_source, + expected_docs_url_messages, +): + class MockListHttpStream(MockHttpStream): + async def read_records(self, *args, **kvargs): + async for record in super().read_records(*args, **kvargs): + yield record + + http_stream = MockListHttpStream() + assert isinstance(http_stream.availability_strategy, AsyncHttpAvailabilityStrategy) + + class MockSource(AsyncAbstractSource): + def __init__(self, streams: List[Stream] = None): + self._streams = streams + super().__init__() + + async def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: + return True, "" + + async def streams(self, config: Mapping[str, Any]) -> List[Stream]: + if not self._streams: + raise Exception("Stream is not set") + return self._streams + + await http_stream.ensure_session() + + with aioresponses() as m: + m.get(http_stream.url_base, status=status_code) + + if include_source: + source = SourceDispatcher(MockSource(streams=[http_stream])) + actual_is_available, reason = await http_stream.check_availability( + logger, source.async_source + ) + else: + actual_is_available, reason = await http_stream.check_availability(logger) + + assert expected_is_available == actual_is_available + if expected_is_available: + assert reason is None + else: + all_expected_messages = expected_messages + expected_docs_url_messages + for message in all_expected_messages: + assert message in reason + + await http_stream._session.close() + + +def test_http_availability_raises_unhandled_error(mocker): + http_stream = MockHttpStream() + assert 
isinstance(http_stream.availability_strategy, AsyncHttpAvailabilityStrategy) + + loop = asyncio.get_event_loop() + loop.run_until_complete(http_stream.ensure_session()) + + with aioresponses() as m: + m.get(http_stream.url_base, status=404) + + with pytest.raises(HttpError): + loop.run_until_complete(http_stream.check_availability(logger)) + + +def test_send_handles_retries_when_checking_availability(caplog): + http_stream = MockHttpStream() + assert isinstance(http_stream.availability_strategy, AsyncHttpAvailabilityStrategy) + + loop = asyncio.get_event_loop() + loop.run_until_complete(http_stream.ensure_session()) + + call_counter = 0 + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + with aioresponses() as m: + m.get(http_stream.url_base, status=429, callback=request_callback) + m.get(http_stream.url_base, status=503, callback=request_callback) + m.get(http_stream.url_base, status=200, callback=request_callback) + + with caplog.at_level(logging.INFO): + stream_is_available, _ = loop.run_until_complete( + http_stream.check_availability(logger) + ) + + assert stream_is_available + assert call_counter == 3 + for message in [ + "Caught retryable error", + "Response Code: 429", + "Response Code: 503", + ]: + assert message in caplog.text + + +def test_http_availability_strategy_on_empty_stream(mocker): + empty_stream_called = False + + async def empty_aiter(*args, **kwargs): + nonlocal empty_stream_called + empty_stream_called = True + yield + + class MockEmptyHttpStream(mocker.MagicMock, MockHttpStream): + def __init__(self, *args, **kvargs): + mocker.MagicMock.__init__(self) + self.read_records = empty_aiter + + empty_stream = MockEmptyHttpStream() + assert isinstance(empty_stream, AsyncHttpStream) + assert isinstance(empty_stream.availability_strategy, AsyncHttpAvailabilityStrategy) + + logger = logging.getLogger("airbyte.test-source") + loop = asyncio.get_event_loop() + stream_is_available, _ = loop.run_until_complete( + empty_stream.check_availability(logger) + ) + + assert stream_is_available + assert empty_stream_called diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_http_async.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_http_async.py new file mode 100644 index 000000000000..e62bcf5cd82b --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/http/test_http_async.py @@ -0,0 +1,964 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
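The module below configures small AsyncHttpStream subclasses and toggles their retry knobs (max_retries, retry_factor, raise_on_http_errors); a minimal sketch of such a subclass, assuming the AsyncHttpStream interface exercised throughout this patch, is shown here. The endpoint and class name are illustrative.

from typing import Any, Iterable, Mapping, Optional

import aiohttp

from airbyte_cdk.sources.async_cdk.streams.http.http_async import AsyncHttpStream


class TinyStream(AsyncHttpStream):
    # Illustrative subclass: one unauthenticated GET endpoint with a short retry policy.
    url_base = "https://example.com"
    primary_key = None
    max_retries = 3  # retry the statuses the stream treats as retryable (429/5xx in the tests below)
    retry_factor = 0.01  # keep the exponential backoff short
    raise_on_http_errors = True

    def path(self, **kwargs) -> str:
        return "items"

    async def next_page_token(self, response: aiohttp.ClientResponse) -> Optional[Mapping[str, Any]]:
        return None  # single page

    async def parse_response(self, response: aiohttp.ClientResponse, **kwargs) -> Iterable[Mapping[str, Any]]:
        yield await response.json()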
+# + + +import asyncio +import json +from http import HTTPStatus +from typing import Any, Iterable, Mapping, Optional +from unittest.mock import AsyncMock, MagicMock, patch +from yarl import URL + +import aiohttp +import pytest +from aioresponses import CallbackResult, aioresponses +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.async_cdk.streams.http.http_async import ( + AsyncHttpStream, + AsyncHttpSubStream, +) +from airbyte_cdk.sources.streams.http.exceptions import RequestBodyException +from airbyte_cdk.sources.async_cdk.streams.http.exceptions_async import ( + AsyncDefaultBackoffException, + AsyncUserDefinedBackoffException, +) +from airbyte_cdk.sources.streams.http.auth import NoAuth +from airbyte_cdk.sources.streams.http.auth import ( + TokenAuthenticator as HttpTokenAuthenticator, +) +from airbyte_cdk.sources.streams.http.utils import HttpError + + +class StubBasicReadHttpStream(AsyncHttpStream): + url_base = "https://test_base_url.com" + primary_key = "" + + def __init__(self, deduplicate_query_params: bool = False, **kwargs): + super().__init__(**kwargs) + self.resp_counter = 1 + self._deduplicate_query_params = deduplicate_query_params + + async def next_page_token( + self, response: aiohttp.ClientResponse + ) -> Optional[Mapping[str, Any]]: + return None + + def path(self, **kwargs) -> str: + return "" + + async def parse_response( + self, response: aiohttp.ClientResponse, **kwargs + ) -> Iterable[Mapping]: + stubResp = {"data": self.resp_counter} + self.resp_counter += 1 + yield stubResp + + def must_deduplicate_query_params(self) -> bool: + return self._deduplicate_query_params + + +def test_default_authenticator(): + stream = StubBasicReadHttpStream() + assert isinstance(stream.authenticator, NoAuth) + + +def test_http_token_authenticator(): + stream = StubBasicReadHttpStream(authenticator=HttpTokenAuthenticator("test-token")) + assert isinstance(stream.authenticator, HttpTokenAuthenticator) + + +def test_request_kwargs_used(mocker): + loop = asyncio.get_event_loop() + stream = StubBasicReadHttpStream() + loop.run_until_complete(stream.ensure_session()) + request_kwargs = {"chunked": True, "compress": True} + mocker.patch.object(stream, "request_kwargs", return_value=request_kwargs) + + with aioresponses() as m: + m.get(stream.url_base, status=200) + loop.run_until_complete(read_records(stream)) + + m.assert_any_call(stream.url_base, "GET", **request_kwargs) + m.assert_called_once() + + loop.run_until_complete( + stream._session.close() + ) # TODO - find a way to not manually close after each test + + +async def read_records(stream, sync_mode=SyncMode.full_refresh, stream_slice=None): + records = [] + async for record in stream.read_records( + sync_mode=sync_mode, stream_slice=stream_slice + ): + records.append(record) + return records + + +def test_stub_basic_read_http_stream_read_records(mocker): + loop = asyncio.get_event_loop() + stream = StubBasicReadHttpStream() + blank_response = ( + {} + ) # Send a blank response is fine as we ignore the response in `parse_response anyway. 
+ mocker.patch.object( + StubBasicReadHttpStream, "_send_request", return_value=blank_response + ) + + records = loop.run_until_complete(read_records(stream)) + + assert [{"data": 1}] == records + + +class StubNextPageTokenHttpStream(StubBasicReadHttpStream): + current_page = 0 + + def __init__(self, pages: int = 5): + super().__init__() + self._pages = pages + + async def next_page_token( + self, response: aiohttp.ClientResponse + ) -> Optional[Mapping[str, Any]]: + while self.current_page < self._pages: + page_token = {"page": self.current_page} + self.current_page += 1 + return page_token + return None + + +def test_next_page_token_is_input_to_other_methods(mocker): + """Validates that the return value from next_page_token is passed into other methods that need it like request_params, headers, body, etc..""" + pages = 5 + stream = StubNextPageTokenHttpStream(pages=pages) + blank_response = ( + {} + ) # Send a blank response is fine as we ignore the response in `parse_response anyway. + mocker.patch.object( + StubNextPageTokenHttpStream, "_send_request", return_value=blank_response + ) + + methods = ["request_params", "request_headers", "request_body_json"] + for method in methods: + # Wrap all methods we're interested in testing with mocked objects so we can later spy on their input args and verify they were what we expect + mocker.patch.object(stream, method, wraps=getattr(stream, method)) + + loop = asyncio.get_event_loop() + records = loop.run_until_complete(read_records(stream)) + + # Since we have 5 pages, we expect 5 tokens which are {"page":1}, {"page":2}, etc... + expected_next_page_tokens = [{"page": i} for i in range(pages)] + for method in methods: + # First assert that they were called with no next_page_token. This is the first call in the pagination loop. 
+ getattr(stream, method).assert_any_call( + next_page_token=None, stream_slice=None, stream_state={} + ) + for token in expected_next_page_tokens: + # Then verify that each method + getattr(stream, method).assert_any_call( + next_page_token=token, stream_slice=None, stream_state={} + ) + + expected = [ + {"data": 1}, + {"data": 2}, + {"data": 3}, + {"data": 4}, + {"data": 5}, + {"data": 6}, + ] + + assert expected == records + + +class StubBadUrlHttpStream(StubBasicReadHttpStream): + url_base = "bad_url" + + +def test_stub_bad_url_http_stream_read_records(): + stream = StubBadUrlHttpStream() + loop = asyncio.get_event_loop() + with pytest.raises(aiohttp.client_exceptions.InvalidURL): + loop.run_until_complete(read_records(stream)) + + +class StubCustomBackoffHttpStream(StubBasicReadHttpStream): + def backoff_time(self, response: aiohttp.ClientResponse) -> Optional[float]: + return 0.5 + + +def test_stub_custom_backoff_http_stream(mocker): + mocker.patch("time.sleep", lambda x: None) + stream = StubCustomBackoffHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + call_counter = 0 + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + with aioresponses() as m: + m.get(stream.url_base, status=429, repeat=True, callback=request_callback) + + with pytest.raises(AsyncUserDefinedBackoffException): + loop.run_until_complete(read_records(stream)) + + assert call_counter == stream.max_retries + 1 + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.parametrize("retries", [-20, -1, 0, 1, 2, 10]) +def test_stub_custom_backoff_http_stream_retries(mocker, retries): + mocker.patch("time.sleep", lambda x: None) + + class StubCustomBackoffHttpStreamRetries(StubCustomBackoffHttpStream): + @property + def max_retries(self): + return retries + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + stream = StubCustomBackoffHttpStreamRetries() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + call_counter = 0 + + with aioresponses() as m: + m.get(stream.url_base, status=429, repeat=True, callback=request_callback) + + with pytest.raises(AsyncUserDefinedBackoffException) as excinfo: + loop.run_until_complete(read_records(stream)) + assert isinstance(excinfo.value.request, aiohttp.ClientRequest) + assert isinstance(excinfo.value.response, aiohttp.ClientResponse) + + if retries <= 0: + m.assert_called_once() + else: + assert call_counter == stream.max_retries + 1 + loop.run_until_complete(stream._session.close()) + + +def test_stub_custom_backoff_http_stream_endless_retries(mocker): + mocker.patch("time.sleep", lambda x: None) + + class StubCustomBackoffHttpStreamRetries(StubCustomBackoffHttpStream): + @property + def max_retries(self): + return None + + stream = StubCustomBackoffHttpStreamRetries() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + infinite_number = 20 + call_counter = 0 + + with aioresponses() as m: + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + if call_counter > infinite_number: + # Simulate a different response or a break in the pattern + # to stop the infinite retries + raise RuntimeError("End of retries") + + m.get( + stream.url_base, + status=HTTPStatus.TOO_MANY_REQUESTS, + repeat=True, + callback=request_callback, + ) + + # Expecting mock object to raise a RuntimeError when the end of side_effect list parameter reached. 
+ with pytest.raises(RuntimeError): + loop.run_until_complete(read_records(stream)) + + assert call_counter == infinite_number + 1 + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.parametrize("http_code", [400, 401, 403]) +def test_4xx_error_codes_http_stream(http_code): + stream = StubCustomBackoffHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + with aioresponses() as m: + m.get(stream.url_base, status=http_code, repeat=True) + + with pytest.raises(HttpError): + loop.run_until_complete(read_records(stream)) + + loop.run_until_complete(stream._session.close()) + + +class AutoFailFalseHttpStream(StubBasicReadHttpStream): + raise_on_http_errors = False + max_retries = 3 + retry_factor = 0.01 + + +def test_raise_on_http_errors_off_429(): + stream = AutoFailFalseHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + with aioresponses() as m: + m.get(stream.url_base, status=429, repeat=True) + with pytest.raises(AsyncDefaultBackoffException): + loop.run_until_complete(read_records(stream)) + + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.parametrize("status_code", [500, 501, 503, 504]) +def test_raise_on_http_errors_off_5xx(status_code): + stream = AutoFailFalseHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + call_counter = 0 + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + with aioresponses() as m: + m.get( + stream.url_base, status=status_code, repeat=True, callback=request_callback + ) + with pytest.raises(AsyncDefaultBackoffException): + loop.run_until_complete(read_records(stream)) + + assert call_counter == stream.max_retries + 1 + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.parametrize("status_code", [400, 401, 402, 403, 416]) +def test_raise_on_http_errors_off_non_retryable_4xx(status_code): + stream = AutoFailFalseHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + with aioresponses() as m: + m.get(stream.url_base, status=status_code, repeat=True) + response = loop.run_until_complete( + stream._send_request(aiohttp.ClientRequest("GET", URL(stream.url_base)), {}) + ) + + assert response.status == status_code + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.parametrize( + "error", + ( + aiohttp.ClientPayloadError, + aiohttp.ServerDisconnectedError, + aiohttp.ServerConnectionError, + aiohttp.ServerTimeoutError, + ), +) +def test_raise_on_http_errors(error): + stream = AutoFailFalseHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + call_counter = 0 + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + with aioresponses() as m: + m.get( + stream.url_base, repeat=True, callback=request_callback, exception=error() + ) + + with pytest.raises(error): + loop.run_until_complete(read_records(stream)) + + assert call_counter == stream.max_retries + 1 + loop.run_until_complete(stream._session.close()) + + +class PostHttpStream(StubBasicReadHttpStream): + http_method = "POST" + + async def parse_response( + self, response: aiohttp.ClientResponse, **kwargs + ) -> Iterable[Mapping]: + """Returns response data as is""" + yield response.json() + + +class TestRequestBody: + """Suite of different tests for request bodies""" + + json_body = {"key": "value"} + data_body = "key:value" + form_body = 
{"key1": "value1", "key2": 1234} + urlencoded_form_body = "key1=value1&key2=1234" + + def request2response(self, **kwargs): + """Callback function to handle request and return mock response.""" + body = kwargs.get("data") + headers = kwargs.get("headers", {}) + return { + "body": json.dumps(body) if isinstance(body, dict) else body, + "content_type": headers.get("Content-Type"), + } + + def test_json_body(self, mocker): + stream = PostHttpStream() + mocker.patch.object(stream, "request_body_json", return_value=self.json_body) + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + with aioresponses() as m: + m.post( + stream.url_base, + payload=self.request2response( + data=self.json_body, headers={"Content-Type": "application/json"} + ), + ) + + response = [] + for r in loop.run_until_complete(read_records(stream)): + response.append(loop.run_until_complete(r)) + + assert response[0]["content_type"] == "application/json" + assert json.loads(response[0]["body"]) == self.json_body + loop.run_until_complete(stream._session.close()) + + def test_text_body(self, mocker): + stream = PostHttpStream() + mocker.patch.object(stream, "request_body_data", return_value=self.data_body) + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + with aioresponses() as m: + m.post(stream.url_base, payload=self.request2response(data=self.data_body)) + + response = [] + for r in loop.run_until_complete(read_records(stream)): + response.append(loop.run_until_complete(r)) + + assert response[0]["content_type"] is None + assert response[0]["body"] == self.data_body + loop.run_until_complete(stream._session.close()) + + def test_form_body(self, mocker): + raise NotImplementedError("This is not supported for the async flow yet.") + + def test_text_json_body(self, mocker): + """checks a exception if both functions were overridden""" + stream = PostHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + mocker.patch.object(stream, "request_body_data", return_value=self.data_body) + mocker.patch.object(stream, "request_body_json", return_value=self.json_body) + + with aioresponses() as m: + m.post(stream.url_base, payload=self.request2response(data=self.data_body)) + with pytest.raises(RequestBodyException): + loop.run_until_complete(read_records(stream)) + + loop.run_until_complete(stream._session.close()) + + def test_body_for_all_methods(self, mocker, requests_mock): + """Stream must send a body for GET/POST/PATCH/PUT methods only""" + stream = PostHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + methods = { + "POST": True, + "PUT": True, + "PATCH": True, + "GET": True, + "DELETE": False, + "OPTIONS": False, + } + for method, with_body in methods.items(): + stream.http_method = method + mocker.patch.object( + stream, "request_body_data", return_value=self.data_body + ) + + with aioresponses() as m: + if method == "POST": + request = m.post + elif method == "PUT": + request = m.put + elif method == "PATCH": + request = m.patch + elif method == "GET": + request = m.get + elif method == "DELETE": + request = m.delete + elif method == "OPTIONS": + request = m.options + + request( + stream.url_base, payload=self.request2response(data=self.data_body) + ) + + response = [] + for r in loop.run_until_complete(read_records(stream)): + response.append(loop.run_until_complete(r)) + + # The requests library flow strips the body where `with_body` is False, but + # 
aiohttp does not. + assert response[0]["body"] == self.data_body + + loop.run_until_complete(stream._session.close()) + + +class CacheHttpStream(StubBasicReadHttpStream): + use_cache = True + + +class CacheHttpSubStream(AsyncHttpSubStream): + url_base = "https://example.com" + primary_key = "" + + def __init__(self, parent): + super().__init__(parent=parent) + + async def parse_response( + self, response: aiohttp.ClientResponse, **kwargs + ) -> Iterable[Mapping]: + yield None + + def next_page_token( + self, response: aiohttp.ClientResponse + ) -> Optional[Mapping[str, Any]]: + return None + + def path(self, **kwargs) -> str: + return "" + + +def test_caching_filename(): + stream = CacheHttpStream() + assert stream.cache_filename == f"{stream.name}.sqlite" + + +def test_caching_sessions_are_different(): + stream_1 = CacheHttpStream() + stream_2 = CacheHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream_1.ensure_session()) + loop.run_until_complete(stream_2.ensure_session()) + + assert stream_1._session != stream_2._session + assert stream_1.cache_filename == stream_2.cache_filename + loop.run_until_complete(stream_1._session.close()) + loop.run_until_complete(stream_2._session.close()) + + +def test_parent_attribute_exist(): + parent_stream = CacheHttpStream() + child_stream = CacheHttpSubStream(parent=parent_stream) + + assert child_stream.parent == parent_stream + + +def test_that_response_was_cached(mocker): + stream = CacheHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + loop.run_until_complete(stream.clear_cache()) + + mocker.patch.object(stream, "url_base", "https://google.com/") + + with aioresponses() as m1: + m1.get(stream.url_base) + records = loop.run_until_complete(read_records(stream)) + m1.assert_called_once() + + with aioresponses() as m2: + m2.get(stream.url_base) + new_records = loop.run_until_complete(read_records(stream)) + m2.assert_not_called() + + assert len(records) == len(new_records) + loop.run_until_complete(stream._session.close()) + + +class CacheHttpStreamWithSlices(CacheHttpStream): + paths = ["", "search"] + + def path(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> str: + return f'{stream_slice["path"]}' if stream_slice else "" + + async def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]: + for path in self.paths: + yield {"path": path} + + async def parse_response( + self, response: aiohttp.ClientResponse, **kwargs + ) -> Iterable[Mapping]: + yield {"value": len(await response.text())} + + +async def stream_slices(stream, sync_mode=SyncMode.full_refresh): + slices = [] + async for s in stream.stream_slices(sync_mode=sync_mode): + slices.append(s) + return slices + + +@patch("airbyte_cdk.sources.streams.core.logging", MagicMock()) +def test_using_cache(mocker): + parent_stream = CacheHttpStreamWithSlices() + loop = asyncio.get_event_loop() + loop.run_until_complete(parent_stream.ensure_session()) + + mocker.patch.object(parent_stream, "url_base", "https://google.com/") + + call_counter = 0 + + def request_callback(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + + async def get_urls(stream): + urls = [] + async for u in stream._session.cache.get_urls(): + urls.append(u) + return urls + + with aioresponses() as m: + # Set up the mocks + slices = loop.run_until_complete( + stream_slices(parent_stream, sync_mode=SyncMode.full_refresh) + ) + m.get(parent_stream.url_base, callback=request_callback) + 
m.get(f"{parent_stream.url_base}search", callback=request_callback) + + loop.run_until_complete(parent_stream.clear_cache()) + assert call_counter == 0 + + # Get the parent stream's records; the responses should be cached + loop.run_until_complete(read_records(parent_stream, stream_slice=slices[0])) + loop.run_until_complete(read_records(parent_stream, stream_slice=slices[1])) + + assert call_counter == 2 + urls = loop.run_until_complete(get_urls(parent_stream)) + assert len(urls) == 2 + + child_stream = CacheHttpSubStream(parent=parent_stream) + loop.run_until_complete(child_stream.ensure_session()) + + # child_stream.stream_slices will call `parent.read_records`, however this shouldn't + # result in a new request to the 3rd party since the response has been cached + loop.run_until_complete( + stream_slices(child_stream, sync_mode=SyncMode.full_refresh) + ) + + assert call_counter == 2 + urls = loop.run_until_complete(get_urls(parent_stream)) + assert len(urls) == 2 + assert URL("https://google.com/") in urls + assert URL("https://google.com/search") in urls + + loop.run_until_complete(parent_stream._session.close()) + loop.run_until_complete(child_stream._session.close()) + + +class AutoFailTrueHttpStream(StubBasicReadHttpStream): + raise_on_http_errors = True + + +@pytest.mark.parametrize("status_code", range(400, 600)) +def test_send_raise_on_http_errors_logs(mocker, status_code): + mocker.patch.object(AutoFailTrueHttpStream, "logger") + mocker.patch.object( + AutoFailTrueHttpStream, "should_retry", mocker.Mock(return_value=False) + ) + + stream = AutoFailTrueHttpStream() + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + + req = aiohttp.ClientRequest("GET", URL(stream.url_base)) + + with aioresponses() as m: + m.get(stream.url_base, status=status_code, repeat=True, payload="text") + + with pytest.raises(HttpError): + response = loop.run_until_complete(stream._send_request(req, {})) + stream.logger.error.assert_called_with("text") + assert response.status == status_code + + loop.run_until_complete(stream._session.close()) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "api_response, expected_message", + [ + ({"error": "something broke"}, "something broke"), + ({"error": {"message": "something broke"}}, "something broke"), + ({"error": "err-001", "message": "something broke"}, "something broke"), + ({"failure": {"message": "something broke"}}, "something broke"), + ( + { + "error": { + "errors": [ + {"message": "one"}, + {"message": "two"}, + {"message": "three"}, + ] + } + }, + "one, two, three", + ), + ({"errors": ["one", "two", "three"]}, "one, two, three"), + ({"messages": ["one", "two", "three"]}, "one, two, three"), + ( + {"errors": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, + "one, two, three", + ), + ( + {"error": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, + "one, two, three", + ), + ( + {"errors": [{"error": "one"}, {"error": "two"}, {"error": "three"}]}, + "one, two, three", + ), + ( + { + "failures": [ + {"message": "one"}, + {"message": "two"}, + {"message": "three"}, + ] + }, + "one, two, three", + ), + (["one", "two", "three"], "one, two, three"), + ([{"error": "one"}, {"error": "two"}, {"error": "three"}], "one, two, three"), + ({"error": True}, None), + ({"something_else": "hi"}, None), + ({}, None), + ], +) +async def test_default_parse_response_error_message( + api_response: dict, expected_message: Optional[str] +): + stream = StubBasicReadHttpStream() + response = MagicMock() + 
response.ok = False + response.json = AsyncMock(return_value=api_response) + response.text = AsyncMock() + response.content = AsyncMock() + exc_message = None + try: + await stream.handle_response_with_error(response) + except HttpError as exc: + exc_message = stream.get_error_display_message(exc) + + assert exc_message == expected_message + +""" + stream = StubBasicReadHttpStream() + response = MagicMock() + response.json.return_value = api_response + + message = stream.parse_response_error_message(response) + assert message == expected_message +""" + + +@pytest.mark.asyncio +async def test_default_parse_response_error_message_not_json(): + stream = StubBasicReadHttpStream() + await stream.ensure_session() + url = "mock://test.com/not_json" + + with aioresponses() as m: + m.get( + url, + callback=lambda *_, **__: CallbackResult( + status=400, body="this is not json" + ), + ) + with pytest.raises(HttpError): + response = await stream._send_request( + aiohttp.ClientRequest("GET", URL(url)), {} + ) + message = await stream.parse_response_error_message(response) + assert message is None + await stream._session.close() + + +def test_default_get_error_display_message_handles_http_error(mocker): + stream = StubBasicReadHttpStream() + + mocker.patch.object( + stream, "parse_error_message", return_value="my custom message" + ) + + non_http_err_msg = stream.get_error_display_message(RuntimeError("not me")) + assert non_http_err_msg is None + + req = aiohttp.ClientRequest("GET", URL(stream.url_base)) + + error = HttpError(aiohttp_error=aiohttp.ClientResponseError(request_info=req.request_info, history=(), status=400, message="", headers={})) + http_err_msg = stream.get_error_display_message(error) + assert http_err_msg == "my custom message" + + +@pytest.mark.parametrize( + "test_name, base_url, path, expected_full_url", + [ + ( + "test_no_slashes", + "https://airbyte.io", + "my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_trailing_slash_on_base_url", + "https://airbyte.io/", + "my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_trailing_slash_on_base_url_and_leading_slash_on_path", + "https://airbyte.io/", + "/my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_leading_slash_on_path", + "https://airbyte.io", + "/my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_trailing_slash_on_path", + "https://airbyte.io", + "/my_endpoint/", + "https://airbyte.io/my_endpoint/", + ), + ( + "test_nested_path_no_leading_slash", + "https://airbyte.io", + "v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), + ( + "test_nested_path_with_leading_slash", + "https://airbyte.io", + "/v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), + ], +) +def test_join_url(test_name, base_url, path, expected_full_url): + actual_url = AsyncHttpStream._join_url(base_url, path) + assert actual_url == expected_full_url + + +@pytest.mark.parametrize( + "deduplicate_query_params, path, params, expected_url", + [ + pytest.param( + True, + "v1/endpoint?param1=value1", + {}, + "https://test_base_url.com/v1/endpoint?param1=value1", + id="test_params_only_in_path", + ), + pytest.param( + True, + "v1/endpoint", + {"param1": "value1"}, + "https://test_base_url.com/v1/endpoint?param1=value1", + id="test_params_only_in_path", + ), + pytest.param( + True, + "v1/endpoint", + None, + "https://test_base_url.com/v1/endpoint", + id="test_params_is_none_and_no_params_in_path", + ), + pytest.param( + True, + "v1/endpoint?param1=value1", + None, + 
"https://test_base_url.com/v1/endpoint?param1=value1", + id="test_params_is_none_and_no_params_in_path", + ), + pytest.param( + True, + "v1/endpoint?param1=value1", + {"param2": "value2"}, + "https://test_base_url.com/v1/endpoint?param1=value1¶m2=value2", + id="test_no_duplicate_params", + ), + pytest.param( + True, + "v1/endpoint?param1=value1", + {"param1": "value1"}, + "https://test_base_url.com/v1/endpoint?param1=value1", + id="test_duplicate_params_same_value", + ), + pytest.param( + True, + "v1/endpoint?param1=1", + {"param1": 1}, + "https://test_base_url.com/v1/endpoint?param1=1", + id="test_duplicate_params_same_value_not_string", + ), + pytest.param( + True, + "v1/endpoint?param1=value1", + {"param1": "value2"}, + "https://test_base_url.com/v1/endpoint?param1=value1¶m1=value2", + id="test_duplicate_params_different_value", + ), + pytest.param( + False, + "v1/endpoint?param1=value1", + {"param1": "value2"}, + "https://test_base_url.com/v1/endpoint?param1=value1¶m1=value2", + id="test_same_params_different_value_no_deduplication", + ), + pytest.param( + False, + "v1/endpoint?param1=value1", + {"param1": "value1"}, + "https://test_base_url.com/v1/endpoint?param1=value1¶m1=value1", + id="test_same_params_same_value_no_deduplication", + ), + ], +) +def test_duplicate_request_params_are_deduped( + deduplicate_query_params, path, params, expected_url +): + stream = StubBasicReadHttpStream(deduplicate_query_params) + + if expected_url is None: + with pytest.raises(ValueError): + stream._create_prepared_request(path=path, params=params) + else: + prepared_request = stream._create_prepared_request(path=path, params=params) + assert str(prepared_request.url) == expected_url + + +def test_connection_pool(): + stream = StubBasicReadHttpStream(authenticator=HttpTokenAuthenticator("test-token")) + loop = asyncio.get_event_loop() + loop.run_until_complete(stream.ensure_session()) + assert stream._session.connector.limit == 20 + loop.run_until_complete(stream._session.close()) diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_availability_strategy_async.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_availability_strategy_async.py new file mode 100644 index 000000000000..4f62ddbf73fb --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_availability_strategy_async.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# + +import asyncio +import logging +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources import Source +from airbyte_cdk.sources.async_cdk.streams.availability_strategy_async import AsyncAvailabilityStrategy +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.core import StreamData + +logger = logging.getLogger("airbyte") + + +class MockStream(Stream): + def __init__(self, name: str) -> Stream: + self._name = name + + @property + def name(self) -> str: + return self._name + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + pass + + def read_records( + self, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> Iterable[StreamData]: + pass + + +def test_no_availability_strategy(): + stream_1 = MockStream("stream") + assert stream_1.availability_strategy is None + + stream_1_is_available, _ = stream_1.check_availability(logger) + assert stream_1_is_available + + +def test_availability_strategy(): + class MockAvailabilityStrategy(AsyncAvailabilityStrategy): + async def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional[Source]) -> Tuple[bool, any]: + if stream.name == "available_stream": + return True, None + return False, f"Could not reach stream '{stream.name}'." + + class MockStreamWithAvailabilityStrategy(MockStream): + @property + def availability_strategy(self) -> Optional["AsyncAvailabilityStrategy"]: + return MockAvailabilityStrategy() + + stream_1 = MockStreamWithAvailabilityStrategy("available_stream") + stream_2 = MockStreamWithAvailabilityStrategy("unavailable_stream") + loop = asyncio.get_event_loop() + + for stream in [stream_1, stream_2]: + assert isinstance(stream.availability_strategy, MockAvailabilityStrategy) + + stream_1_is_available, _ = loop.run_until_complete(stream_1.check_availability(logger)) + assert stream_1_is_available + + stream_2_is_available, reason = loop.run_until_complete(stream_2.check_availability(logger)) + assert not stream_2_is_available + assert "Could not reach stream 'unavailable_stream'" in reason diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_streams_core_async.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_streams_core_async.py new file mode 100644 index 000000000000..03d794f1c3a4 --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/test_streams_core_async.py @@ -0,0 +1,184 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + + +from typing import Any, Iterable, List, Mapping +from unittest import mock + +import pytest +from airbyte_cdk.models import AirbyteStream, SyncMode +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream + + +class StreamStubFullRefresh(AsyncStream): + """ + Stub full refresh class to assist with testing. + """ + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> Iterable[Mapping[str, Any]]: + pass + + primary_key = None + + +def test_as_airbyte_stream_full_refresh(mocker): + """ + Should return an full refresh AirbyteStream with information matching the + provided Stream interface. 
+ """ + test_stream = StreamStubFullRefresh() + + mocker.patch.object(StreamStubFullRefresh, "get_json_schema", return_value={}) + airbyte_stream = test_stream.as_airbyte_stream() + + exp = AirbyteStream(name="stream_stub_full_refresh", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]) + assert exp == airbyte_stream + + +class StreamStubIncremental(AsyncStream): + """ + Stub full incremental class to assist with testing. + """ + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> Iterable[Mapping[str, Any]]: + pass + + cursor_field = "test_cursor" + primary_key = "primary_key" + namespace = "test_namespace" + + +class StreamStubIncrementalEmptyNamespace(AsyncStream): + """ + Stub full incremental class, with empty namespace, to assist with testing. + """ + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> Iterable[Mapping[str, Any]]: + pass + + cursor_field = "test_cursor" + primary_key = "primary_key" + namespace = "" + + +def test_as_airbyte_stream_incremental(mocker): + """ + Should return an incremental refresh AirbyteStream with information matching + the provided Stream interface. + """ + test_stream = StreamStubIncremental() + + mocker.patch.object(StreamStubIncremental, "get_json_schema", return_value={}) + airbyte_stream = test_stream.as_airbyte_stream() + + exp = AirbyteStream( + name="stream_stub_incremental", + namespace="test_namespace", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + default_cursor_field=["test_cursor"], + source_defined_cursor=True, + source_defined_primary_key=[["primary_key"]], + ) + assert exp == airbyte_stream + + +def test_supports_incremental_cursor_set(): + """ + Should return true if cursor is set. + """ + test_stream = StreamStubIncremental() + test_stream.cursor_field = "test_cursor" + + assert test_stream.supports_incremental + + +def test_supports_incremental_cursor_not_set(): + """ + Should return false if cursor is not. + """ + test_stream = StreamStubFullRefresh() + + assert not test_stream.supports_incremental + + +def test_namespace_set(): + """ + Should allow namespace property to be set. + """ + test_stream = StreamStubIncremental() + + assert test_stream.namespace == "test_namespace" + + +def test_namespace_set_to_empty_string(mocker): + """ + Should not set namespace property if equal to empty string. + """ + test_stream = StreamStubIncremental() + + mocker.patch.object(StreamStubIncremental, "get_json_schema", return_value={}) + mocker.patch.object(StreamStubIncremental, "namespace", "") + + airbyte_stream = test_stream.as_airbyte_stream() + + exp = AirbyteStream( + name="stream_stub_incremental", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + default_cursor_field=["test_cursor"], + source_defined_cursor=True, + source_defined_primary_key=[["primary_key"]], + namespace=None, + ) + assert exp == airbyte_stream + + +def test_namespace_not_set(): + """ + Should be equal to unset value of None. 
+ """ + test_stream = StreamStubFullRefresh() + + assert test_stream.namespace is None + + +@pytest.mark.parametrize( + "test_input, expected", + [("key", [["key"]]), (["key1", "key2"], [["key1"], ["key2"]]), ([["key1", "key2"], ["key3"]], [["key1", "key2"], ["key3"]])], +) +def test_wrapped_primary_key_various_argument(test_input, expected): + """ + Should always wrap primary key into list of lists. + """ + + wrapped = AsyncStream._wrapped_primary_key(test_input) + + assert wrapped == expected + + +@mock.patch("airbyte_cdk.sources.utils.schema_helpers.ResourceSchemaLoader.get_schema") +def test_get_json_schema_is_cached(mocked_method): + stream = StreamStubFullRefresh() + for i in range(5): + stream.get_json_schema() + assert mocked_method.call_count == 1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/utils/__init__.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/streams/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/async_cdk/test_abstract_source_async.py b/airbyte-cdk/python/unit_tests/sources/async_cdk/test_abstract_source_async.py new file mode 100644 index 000000000000..06b9d6b38b6c --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/async_cdk/test_abstract_source_async.py @@ -0,0 +1,1218 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import asyncio +import copy +import datetime +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from unittest.mock import Mock, call +from unittest.mock import AsyncMock + +import pytest +from airbyte_cdk.models import ( + AirbyteCatalog, + AirbyteConnectionStatus, + AirbyteLogMessage, + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStream, + AirbyteStreamState, + AirbyteStreamStatus, + AirbyteStreamStatusTraceMessage, + AirbyteTraceMessage, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + DestinationSyncMode, + Level, + Status, + StreamDescriptor, + SyncMode, + TraceType, +) +from airbyte_cdk.models import Type +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.source_dispatcher import SourceDispatcher +from airbyte_cdk.sources.async_cdk.streams.core_async import AsyncStream +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams import IncrementalMixin +from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message +from airbyte_cdk.utils.traced_exception import AirbyteTracedException +from pytest import fixture + +logger = logging.getLogger("airbyte") + + +class MockSource(AsyncAbstractSource): + def __init__( + self, + check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None, + streams: List[AsyncStream] = None, + per_stream: bool = True, + message_repository: MessageRepository = None, + exception_on_missing_stream: bool = True, + ): + self._streams = streams + self.check_lambda = check_lambda + self.per_stream = per_stream + self.exception_on_missing_stream = exception_on_missing_stream + self._message_repository = message_repository + super().__init__() + + async def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + if 
self.check_lambda: + return self.check_lambda() + return False, "Missing callable." + + async def streams(self, config: Mapping[str, Any]) -> List[AsyncStream]: + if not self._streams: + raise Exception("Stream is not set") + return self._streams + + @property + def raise_exception_on_missing_stream(self) -> bool: + return self.exception_on_missing_stream + + @property + def per_stream_state_enabled(self) -> bool: + return self.per_stream + + @property + def message_repository(self): + return self._message_repository + + +class StreamNoStateMethod(AsyncStream): + name = "managers" + primary_key = None + + async def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]: + yield {} + + +class MockStreamOverridesStateMethod(AsyncStream, IncrementalMixin): + name = "teams" + primary_key = None + cursor_field = "updated_at" + _cursor_value = "" + start_date = "1984-12-12" + + async def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]: + yield {} + + @property + def state(self) -> MutableMapping[str, Any]: + return {self.cursor_field: self._cursor_value} if self._cursor_value else {} + + @state.setter + def state(self, value: MutableMapping[str, Any]): + self._cursor_value = value.get(self.cursor_field, self.start_date) + + +MESSAGE_FROM_REPOSITORY = Mock() + + +@fixture +def message_repository(): + message_repository = Mock(spec=MessageRepository) + message_repository.consume_queue.return_value = [message for message in [MESSAGE_FROM_REPOSITORY]] + return message_repository + + +def test_successful_check(): + """Tests that if a source returns TRUE for the connection check the appropriate connectionStatus success message is returned""" + expected = AirbyteConnectionStatus(status=Status.SUCCEEDED) + loop = asyncio.get_event_loop() + assert expected == loop.run_until_complete(MockSource(check_lambda=lambda: (True, None)).check(logger, {})) + + +def test_failed_check(): + """Tests that if a source returns FALSE for the connection check the appropriate connectionStatus failure message is returned""" + expected = AirbyteConnectionStatus(status=Status.FAILED, message="'womp womp'") + loop = asyncio.get_event_loop() + assert expected == loop.run_until_complete(MockSource(check_lambda=lambda: (False, "womp womp")).check(logger, {})) + + +def test_raising_check(mocker): + """Tests that if a source raises an unexpected exception the appropriate connectionStatus failure message is returned.""" + check_lambda = mocker.Mock(side_effect=BaseException("this should fail")) + loop = asyncio.get_event_loop() + with pytest.raises(BaseException): + loop.run_until_complete(MockSource(check_lambda=check_lambda).check(logger, {})) + + +class MockStream(AsyncStream): + def __init__( + self, + inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]] = None, + name: str = None, + ): + self._inputs_and_mocked_outputs = inputs_and_mocked_outputs + self._name = name + + @property + def name(self): + return self._name + + async def read_records(self, error: Exception = None, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore + # Remove None values + kwargs = {k: v for k, v in kwargs.items() if v is not None} + output_supplied = False + if self._inputs_and_mocked_outputs: + for _input, output in self._inputs_and_mocked_outputs: + if kwargs == _input: + if isinstance(output, list): + for item in output: + yield item + output_supplied = True + else: + output_supplied = True + yield output + + if not output_supplied: + raise Exception(f"No mocked output supplied 
for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}") + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + return "pk" + + +class MockStreamWithState(MockStream): + cursor_field = "cursor" + + def __init__(self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str, state=None): + super().__init__(inputs_and_mocked_outputs, name) + self._state = state + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + pass + + +class MockStreamEmittingAirbyteMessages(MockStreamWithState): + def __init__( + self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, name: str = None, state=None + ): + super().__init__(inputs_and_mocked_outputs, name, state) + self._inputs_and_mocked_outputs = inputs_and_mocked_outputs + self._name = name + + @property + def name(self): + return self._name + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + return "pk" + + @property + def state(self) -> MutableMapping[str, Any]: + return {self.cursor_field: self._cursor_value} if self._cursor_value else {} + + @state.setter + def state(self, value: MutableMapping[str, Any]): + self._cursor_value = value.get(self.cursor_field, self.start_date) + + +def test_discover(mocker): + """Tests that the appropriate AirbyteCatalog is returned from the discover method""" + airbyte_stream1 = AirbyteStream( + name="1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + default_cursor_field=["cursor"], + source_defined_cursor=True, + source_defined_primary_key=[["pk"]], + ) + airbyte_stream2 = AirbyteStream(name="2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]) + + stream1 = MockStream() + stream2 = MockStream() + mocker.patch.object(stream1, "as_airbyte_stream", return_value=airbyte_stream1) + mocker.patch.object(stream2, "as_airbyte_stream", return_value=airbyte_stream2) + + expected = AirbyteCatalog(streams=[airbyte_stream1, airbyte_stream2]) + src = SourceDispatcher(MockSource(check_lambda=lambda: (True, None), streams=[stream1, stream2])) + + assert expected == src.discover(logger, {}) + + +def test_read_nonexistent_stream_raises_exception(mocker): + """Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception""" + s1 = MockStream(name="s1") + s2 = MockStream(name="this_stream_doesnt_exist_in_the_source") + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + + src = SourceDispatcher(MockSource(streams=[s1])) + catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)]) + with pytest.raises(KeyError): + list(src.read(logger, {}, catalog)) + + +def test_read_nonexistent_stream_without_raises_exception(mocker): + """Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception""" + s1 = MockStream(name="s1") + s2 = MockStream(name="this_stream_doesnt_exist_in_the_source") + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + + src = SourceDispatcher(MockSource(streams=[s1], exception_on_missing_stream=False)) + + catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)]) + messages = list(src.read(logger, {}, catalog)) + + assert messages == [] + + +async def fake_read_records(*args, **kwargs): + for record in [{"a record": "a 
value"}, {"another record": "another value"}]: + yield record + + +def test_read_stream_emits_repository_message_before_record(mocker, message_repository): + stream = MockStream(name="my_stream") + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "read_records", fake_read_records) + message_repository.consume_queue.side_effect = [[message for message in [MESSAGE_FROM_REPOSITORY]], []] + + source = SourceDispatcher(MockSource(streams=[stream], message_repository=message_repository)) + + messages = list(source.read(logger, {}, ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]))) + + assert messages.count(MESSAGE_FROM_REPOSITORY) == 1 + record_messages = (message for message in messages if message.type == Type.RECORD) + assert all(messages.index(MESSAGE_FROM_REPOSITORY) < messages.index(record) for record in record_messages) + + +def test_read_stream_emits_repository_message_on_error(mocker, message_repository): + stream = MockStream(name="my_stream") + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "read_records", side_effect=RuntimeError("error")) + message_repository.consume_queue.return_value = [message for message in [MESSAGE_FROM_REPOSITORY]] + + source = SourceDispatcher(MockSource(streams=[stream], message_repository=message_repository)) + with pytest.raises(RuntimeError): + messages = list(source.read(logger, {}, ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]))) + assert MESSAGE_FROM_REPOSITORY in messages + + +async def read_records_with_error(*args, **kwargs): + if False: + yield + raise RuntimeError("oh no!") + + +def test_read_stream_with_error_gets_display_message(mocker): + stream = MockStream(name="my_stream") + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + stream.read_records = read_records_with_error + + source = SourceDispatcher(MockSource(streams=[stream])) + catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]) + + # without get_error_display_message + with pytest.raises(RuntimeError, match="oh no!"): + list(source.read(logger, {}, catalog)) + + mocker.patch.object(MockStream, "get_error_display_message", return_value="my message") + + with pytest.raises(AirbyteTracedException, match="oh no!") as exc: + list(source.read(logger, {}, catalog)) + assert exc.value.message == "my message" + + +GLOBAL_EMITTED_AT = 1 + + +def _as_record(stream: str, data: Dict[str, Any]) -> AirbyteMessage: + return AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=GLOBAL_EMITTED_AT), + ) + + +def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]: + return [_as_record(stream, datum) for datum in data] + + +def _as_stream_status(stream: str, status: AirbyteStreamStatus) -> AirbyteMessage: + trace_message = AirbyteTraceMessage( + emitted_at=datetime.datetime.now().timestamp() * 1000.0, + type=TraceType.STREAM_STATUS, + stream_status=AirbyteStreamStatusTraceMessage( + stream_descriptor=StreamDescriptor(name=stream), + status=status, + ), + ) + + return AirbyteMessage(type=MessageType.TRACE, trace=trace_message) + + +def _as_state(state_data: Dict[str, Any], stream_name: str = "", per_stream_state: Dict[str, Any] = None): + if per_stream_state: + return AirbyteMessage( + type=Type.STATE, + state=AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + 
stream_descriptor=StreamDescriptor(name=stream_name), stream_state=AirbyteStateBlob.parse_obj(per_stream_state) + ), + data=state_data, + ), + ) + return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data)) + + +def _configured_stream(stream: AsyncStream, sync_mode: SyncMode): + return ConfiguredAirbyteStream( + stream=stream.as_airbyte_stream(), + sync_mode=sync_mode, + destination_sync_mode=DestinationSyncMode.overwrite, + ) + + +def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]: + for msg in messages: + if msg.type == Type.RECORD and msg.record: + msg.record.emitted_at = GLOBAL_EMITTED_AT + if msg.type == Type.TRACE and msg.trace: + msg.trace.emitted_at = GLOBAL_EMITTED_AT + return messages + + +def test_valid_full_refresh_read_no_slices(mocker): + """Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages""" + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") + s2 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s2") + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + + src = SourceDispatcher(MockSource(streams=[s1, s2])) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(s1, SyncMode.full_refresh), + _configured_stream(s2, SyncMode.full_refresh), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + *_as_records("s1", stream_output), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + *_as_records("s2", stream_output), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + messages = _fix_emitted_at(list(src.read(logger, {}, catalog))) + + assert expected == messages + + +def test_valid_full_refresh_read_with_slices(mocker): + """Tests that running a full refresh sync on streams which use slices produces the expected AirbyteMessages""" + # When attempting to sync a slice, just output that slice as a record + slices = [{"1": "1"}, {"2": "2"}] + s1 = MockStream( + [({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + name="s1", + ) + s2 = MockStream( + [({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + name="s2", + ) + + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + + src = SourceDispatcher(MockSource(streams=[s1, s2])) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(s1, SyncMode.full_refresh), + _configured_stream(s2, SyncMode.full_refresh), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + *_as_records("s1", slices), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + *_as_records("s2", slices), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog))) + + assert expected == messages + + +@pytest.mark.parametrize( + "slices", + [ + 
[{"1": "1"}, {"2": "2"}], + [ + {"date": datetime.date(year=2023, month=1, day=1)}, + {"date": datetime.date(year=2023, month=1, day=1)}, + ] + ], +) +def test_read_full_refresh_with_slices_sends_slice_messages(mocker, slices): + """Given the logger is debug and a full refresh, AirbyteMessages are sent for slices""" + debug_logger = logging.getLogger("airbyte.debug") + debug_logger.setLevel(logging.DEBUG) + stream = MockStream( + [({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + name="s1", + ) + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + + src = SourceDispatcher(MockSource(streams=[stream])) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream, SyncMode.full_refresh), + ] + ) + + messages = src.read(debug_logger, {}, catalog) + + assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages))) + + +def test_read_incremental_with_slices_sends_slice_messages(mocker): + """Given the logger is debug and a incremental, AirbyteMessages are sent for slices""" + debug_logger = logging.getLogger("airbyte.debug") + debug_logger.setLevel(logging.DEBUG) + slices = [{"1": "1"}, {"2": "2"}] + stream = MockStream( + [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": {}}, [s]) for s in slices], + name="s1", + ) + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + MockStream.supports_incremental = mocker.PropertyMock(return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + + src = SourceDispatcher(MockSource(streams=[stream])) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream, SyncMode.incremental), + ] + ) + + messages = src.read(debug_logger, {}, catalog) + + assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages))) + + +class TestIncrementalRead: + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_with_state_attribute(self, mocker, use_legacy, per_stream_enabled): + """Test correct state passing for the streams that have a state attribute""" + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + old_state = {"cursor": "old_value"} + if use_legacy: + input_state = {"s1": old_state} + else: + input_state = [ + AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="s1"), stream_state=AirbyteStateBlob.parse_obj(old_state) + ), + ), + ] + new_state_from_connector = {"cursor": "new_value"} + + stream_1 = MockStreamWithState( + [ + ( + {"sync_mode": SyncMode.incremental, "stream_state": old_state}, + stream_output, + ) + ], + name="s1", + ) + stream_2 = MockStreamWithState( + [({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], + name="s2", + ) + mocker.patch.object(MockStreamWithState, "get_updated_state", 
return_value={}) + state_property = mocker.patch.object( + MockStreamWithState, + "state", + new_callable=mocker.PropertyMock, + return_value=new_state_from_connector, + ) + mocker.patch.object(MockStreamWithState, "get_json_schema", return_value={}) + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + _as_record("s1", stream_output[0]), + _as_record("s1", stream_output[1]), + _as_state({"s1": new_state_from_connector}, "s1", new_state_from_connector) + if per_stream_enabled + else _as_state({"s1": new_state_from_connector}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + _as_record("s2", stream_output[0]), + _as_record("s2", stream_output[1]), + _as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}, "s2", new_state_from_connector) + if per_stream_enabled + else _as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert messages == expected + assert state_property.mock_calls == [ + call(old_state), # set state for s1 + call(), # get state in the end of slice for s1 + call(), # get state in the end of slice for s2 + ] + + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_with_checkpoint_interval(self, mocker, use_legacy, per_stream_enabled): + """Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message + after reading N records within a stream. 
+ """ + if use_legacy: + input_state = defaultdict(dict) + else: + input_state = [] + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + + stream_1 = MockStream( + [({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], + name="s1", + ) + stream_2 = MockStream( + [({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], + name="s2", + ) + state = {"cursor": "value"} + mocker.patch.object(MockStream, "get_updated_state", return_value=state) + mocker.patch.object(MockStream, "supports_incremental", return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + # Tell the source to output one state message per record + mocker.patch.object( + MockStream, + "state_checkpoint_interval", + new_callable=mocker.PropertyMock, + return_value=1, + ) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + _as_record("s1", stream_output[0]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_record("s1", stream_output[1]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + _as_record("s2", stream_output[0]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_record("s2", stream_output[1]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert expected == messages + + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_with_no_interval(self, mocker, use_legacy, per_stream_enabled): + """Tests that an incremental read which doesn't specify a checkpoint interval outputs + a STATE message only after fully reading the stream and does not output any STATE messages during syncing the stream. 
+ """ + if use_legacy: + input_state = defaultdict(dict) + else: + input_state = [] + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + + stream_1 = MockStream( + [({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], + name="s1", + ) + stream_2 = MockStream( + [({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], + name="s2", + ) + state = {"cursor": "value"} + mocker.patch.object(MockStream, "get_updated_state", return_value=state) + mocker.patch.object(MockStream, "supports_incremental", return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + *_as_records("s1", stream_output), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + *_as_records("s2", stream_output), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert expected == messages + + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_with_slices(self, mocker, use_legacy, per_stream_enabled): + """Tests that an incremental read which uses slices outputs each record in the slice followed by a STATE message, for each slice""" + if use_legacy: + input_state = defaultdict(dict) + else: + input_state = [] + slices = [{"1": "1"}, {"2": "2"}] + stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}] + + stream_1 = MockStream( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s1", + ) + stream_2 = MockStream( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s2", + ) + + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + state = {"cursor": "value"} + mocker.patch.object(MockStream, "get_updated_state", return_value=state) + mocker.patch.object(MockStream, "supports_incremental", return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + 
expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + # stream 1 slice 1 + *_as_records("s1", stream_output), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + # stream 1 slice 2 + *_as_records("s1", stream_output), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + # stream 2 slice 1 + *_as_records("s2", stream_output), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + # stream 2 slice 2 + *_as_records("s2", stream_output), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert expected == messages + + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize("slices", [pytest.param([], id="test_slices_as_list"), pytest.param(iter([]), id="test_slices_as_iterator")]) + def test_no_slices(self, mocker, use_legacy, per_stream_enabled, slices): + """ + Tests that an incremental read returns at least one state messages even if no records were read: + 1. 
outputs a state message after reading the entire stream + """ + if use_legacy: + input_state = defaultdict(dict) + else: + input_state = [] + + stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}] + state = {"cursor": "value"} + stream_1 = MockStreamWithState( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s1", + state=state, + ) + stream_2 = MockStreamWithState( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s2", + state=state, + ) + + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + mocker.patch.object(MockStreamWithState, "supports_incremental", return_value=True) + mocker.patch.object(MockStreamWithState, "get_json_schema", return_value={}) + mocker.patch.object(MockStreamWithState, "stream_slices", _fake_stream_slices) + mocker.patch.object( + MockStreamWithState, + "state_checkpoint_interval", + new_callable=mocker.PropertyMock, + return_value=2, + ) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert expected == messages + + @pytest.mark.parametrize( + "use_legacy", + [ + pytest.param(True, id="test_incoming_stream_state_as_legacy_format"), + pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"), + ], + ) + @pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(True, id="test_source_emits_state_as_per_stream_format"), + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_with_slices_and_interval(self, mocker, use_legacy, per_stream_enabled): + """ + Tests that an incremental read which uses slices and a checkpoint interval: + 1. outputs all records + 2. outputs a state message every N records (N=checkpoint_interval) + 3. 
outputs a state message after reading the entire slice + """ + if use_legacy: + input_state = defaultdict(dict) + else: + input_state = [] + slices = [{"1": "1"}, {"2": "2"}] + stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}] + stream_1 = MockStream( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s1", + ) + stream_2 = MockStream( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s2", + ) + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + state = {"cursor": "value"} + mocker.patch.object(MockStream, "get_updated_state", return_value=state) + mocker.patch.object(MockStream, "supports_incremental", return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + mocker.patch.object( + MockStream, + "state_checkpoint_interval", + new_callable=mocker.PropertyMock, + return_value=2, + ) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + # stream 1 slice 1 + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + _as_record("s1", stream_output[0]), + _as_record("s1", stream_output[1]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_record("s1", stream_output[2]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + # stream 1 slice 2 + _as_record("s1", stream_output[0]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_record("s1", stream_output[1]), + _as_record("s1", stream_output[2]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + # stream 2 slice 1 + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + _as_record("s2", stream_output[0]), + _as_record("s2", stream_output[1]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_record("s2", stream_output[2]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + # stream 2 slice 2 + _as_record("s2", stream_output[0]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_record("s2", stream_output[1]), + _as_record("s2", stream_output[2]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert messages == expected + + 
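The incremental tests above repeatedly swap an `async def` generator (_fake_stream_slices) in for stream_slices via mocker.patch.object. The sketch below illustrates why that substitution works, using hypothetical names (DemoStream, _collect) rather than CDK classes: an async generator function bound on the class behaves like the original async method and can be consumed with `async for`.

    import asyncio
    from typing import Any, AsyncIterator, Dict, List


    class DemoStream:
        """Illustrative stand-in for an async stream that yields slices."""

        async def stream_slices(self, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
            yield {"real": "slice"}


    async def _fake_stream_slices(*args: Any, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
        # Mirrors the _fake_stream_slices helpers above: ignore the arguments
        # and yield a fixed list of slices.
        for _slice in [{"1": "1"}, {"2": "2"}]:
            yield _slice


    async def _collect(stream: DemoStream) -> List[Dict[str, Any]]:
        return [s async for s in stream.stream_slices()]


    if __name__ == "__main__":
        # setattr plays the role of mocker.patch.object(DemoStream, "stream_slices", _fake_stream_slices).
        setattr(DemoStream, "stream_slices", _fake_stream_slices)
        assert asyncio.run(_collect(DemoStream())) == [{"1": "1"}, {"2": "2"}]

Because the replacement is itself an async generator function, any caller that iterates the method with `async for` (as _collect does here) keeps working unchanged.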
@pytest.mark.parametrize( + "per_stream_enabled", + [ + pytest.param(False, id="test_source_emits_state_as_per_stream_format"), + ], + ) + def test_emit_non_records(self, mocker, per_stream_enabled): + """ + Tests that an incremental read which uses slices and a checkpoint interval: + 1. outputs all records + 2. outputs a state message every N records (N=checkpoint_interval) + 3. outputs a state message after reading the entire slice + """ + + input_state = [] + slices = [{"1": "1"}, {"2": "2"}] + stream_output = [ + {"k1": "v1"}, + AirbyteLogMessage(level=Level.INFO, message="HELLO"), + {"k2": "v2"}, + {"k3": "v3"}, + ] + stream_1 = MockStreamEmittingAirbyteMessages( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s1", + state=copy.deepcopy(input_state), + ) + stream_2 = MockStreamEmittingAirbyteMessages( + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": s, + "stream_state": mocker.ANY, + }, + stream_output, + ) + for s in slices + ], + name="s2", + state=copy.deepcopy(input_state), + ) + async def _fake_stream_slices(*args, **kwargs): + for _slice in slices: + yield _slice + + state = {"cursor": "value"} + mocker.patch.object(MockStream, "get_updated_state", return_value=state) + mocker.patch.object(MockStream, "supports_incremental", return_value=True) + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(MockStream, "stream_slices", _fake_stream_slices) + mocker.patch.object( + MockStream, + "state_checkpoint_interval", + new_callable=mocker.PropertyMock, + return_value=2, + ) + + src = SourceDispatcher(MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(stream_1, SyncMode.incremental), + _configured_stream(stream_2, SyncMode.incremental), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + # stream 1 slice 1 + stream_data_to_airbyte_message("s1", stream_output[0]), + stream_data_to_airbyte_message("s1", stream_output[1]), + stream_data_to_airbyte_message("s1", stream_output[2]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + stream_data_to_airbyte_message("s1", stream_output[3]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + # stream 1 slice 2 + stream_data_to_airbyte_message("s1", stream_output[0]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + stream_data_to_airbyte_message("s1", stream_output[1]), + stream_data_to_airbyte_message("s1", stream_output[2]), + stream_data_to_airbyte_message("s1", stream_output[3]), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + # stream 2 slice 1 + _as_stream_status("s2", AirbyteStreamStatus.STARTED), + _as_stream_status("s2", AirbyteStreamStatus.RUNNING), + stream_data_to_airbyte_message("s2", stream_output[0]), + stream_data_to_airbyte_message("s2", stream_output[1]), + stream_data_to_airbyte_message("s2", stream_output[2]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + 
stream_data_to_airbyte_message("s2", stream_output[3]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + # stream 2 slice 2 + stream_data_to_airbyte_message("s2", stream_output[0]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + stream_data_to_airbyte_message("s2", stream_output[1]), + stream_data_to_airbyte_message("s2", stream_output[2]), + stream_data_to_airbyte_message("s2", stream_output[3]), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}), + _as_stream_status("s2", AirbyteStreamStatus.COMPLETE), + ] + ) + + messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state))) + + assert messages == expected + + +def test_checkpoint_state_from_stream_instance(): + teams_stream = MockStreamOverridesStateMethod() + managers_stream = StreamNoStateMethod() + state_manager = ConnectorStateManager({"teams": teams_stream, "managers": managers_stream}, []) + + # The stream_state passed to checkpoint_state() should be ignored since stream implements state function + teams_stream.state = {"updated_at": "2022-09-11"} + actual_message = teams_stream._checkpoint_state({"ignored": "state"}, state_manager, True) + assert actual_message == _as_state({"teams": {"updated_at": "2022-09-11"}}, "teams", {"updated_at": "2022-09-11"}) + + # The stream_state passed to checkpoint_state() should be used since the stream does not implement state function + actual_message = managers_stream._checkpoint_state({"updated": "expected_here"}, state_manager, True) + assert actual_message == _as_state( + {"teams": {"updated_at": "2022-09-11"}, "managers": {"updated": "expected_here"}}, "managers", {"updated": "expected_here"} + ) diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/avro_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/avro_scenarios.py index f1cdac5838b2..5a2f9da88bf3 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/avro_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/avro_scenarios.py @@ -7,7 +7,7 @@ from unit_tests.sources.file_based.in_memory_files_source import TemporaryAvroFilesStreamReader from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder _single_avro_file = { "a.avro": { diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/check_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/check_scenarios.py index 26136d9cf025..8715dc256d3f 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/check_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/check_scenarios.py @@ -9,7 +9,7 @@ TestErrorOpenFileInMemoryFilesStreamReader, ) from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder _base_success_scenario = ( TestScenarioBuilder() diff --git 
a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py index e6c5824b4e19..186b12c4d4a3 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py @@ -11,7 +11,7 @@ from unit_tests.sources.file_based.helpers import EmptySchemaParser, LowInferenceLimitDiscoveryPolicy from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario, TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenario, TestScenarioBuilder single_csv_scenario: TestScenario[InMemoryFilesSource] = ( TestScenarioBuilder[InMemoryFilesSource]() diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py index 90deb31fe41b..439fd4f9749a 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py @@ -15,7 +15,7 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder +from unit_tests.sources.scenario_based.scenario_builder import SourceBuilder class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]): diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/incremental_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/incremental_scenarios.py index 3c3195fbac61..8e96171b8873 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/incremental_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/incremental_scenarios.py @@ -4,7 +4,7 @@ from unit_tests.sources.file_based.helpers import LowHistoryLimitCursor from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder single_csv_input_state_is_earlier_scenario = ( TestScenarioBuilder() diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py index b4a447c4f0c0..55aa4f786a17 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py @@ -7,7 +7,7 @@ from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.helpers import LowInferenceBytesJsonlParser, LowInferenceLimitDiscoveryPolicy from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import 
TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder single_jsonl_scenario = ( TestScenarioBuilder() diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py index 0852de4a361a..b9717009da67 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py @@ -9,7 +9,7 @@ from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.in_memory_files_source import TemporaryParquetFilesStreamReader from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder _single_parquet_file = { "a.parquet": { diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py index f052c4530e4a..ec4a53c88a60 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py @@ -6,7 +6,7 @@ import nltk from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder # import nltk data for pdf parser nltk.download("punkt") diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py index 58d528cb7caf..8eaed6a2525d 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py @@ -7,7 +7,7 @@ from airbyte_cdk.test.catalog_builder import CatalogBuilder from airbyte_protocol.models import SyncMode from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder """ User input schema rules: diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py index af1318dba647..05d44ce338a4 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py @@ -4,7 +4,7 @@ from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder _base_single_stream_scenario = ( 
TestScenarioBuilder() diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/test_file_based_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/test_file_based_scenarios.py index 6bc58b8edf96..f416e6a3190b 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/test_file_based_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/test_file_based_scenarios.py @@ -5,7 +5,6 @@ from pathlib import PosixPath import pytest -from _pytest.capture import CaptureFixture from airbyte_cdk.sources.abstract_source import AbstractSource from freezegun import freeze_time from unit_tests.sources.file_based.scenarios.avro_scenarios import ( @@ -97,7 +96,7 @@ single_parquet_scenario, single_partitioned_parquet_scenario, ) -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario +from unit_tests.sources.scenario_based.scenario_builder import TestScenario from unit_tests.sources.file_based.scenarios.unstructured_scenarios import ( corrupted_file_scenario, no_file_extension_unstructured_scenario, @@ -126,7 +125,7 @@ wait_for_rediscovery_scenario_multi_stream, wait_for_rediscovery_scenario_single_stream, ) -from unit_tests.sources.file_based.test_scenarios import verify_check, verify_discover, verify_read, verify_spec +from unit_tests.sources.scenario_based.helpers import verify_check, verify_discover, verify_read, verify_spec discover_scenarios = [ csv_multi_stream_scenario, @@ -247,8 +246,8 @@ @pytest.mark.parametrize("scenario", discover_scenarios, ids=[s.name for s in discover_scenarios]) -def test_file_based_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: - verify_discover(capsys, tmp_path, scenario) +def test_file_based_discover(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: + verify_discover(tmp_path, scenario) @pytest.mark.parametrize("scenario", read_scenarios, ids=[s.name for s in read_scenarios]) @@ -258,10 +257,10 @@ def test_file_based_read(scenario: TestScenario[AbstractSource]) -> None: @pytest.mark.parametrize("scenario", spec_scenarios, ids=[c.name for c in spec_scenarios]) -def test_file_based_spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> None: - verify_spec(capsys, scenario) +def test_file_based_spec(scenario: TestScenario[AbstractSource]) -> None: + verify_spec(scenario) @pytest.mark.parametrize("scenario", check_scenarios, ids=[c.name for c in check_scenarios]) -def test_file_based_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: - verify_check(capsys, tmp_path, scenario) +def test_file_based_check(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: + verify_check(tmp_path, scenario) diff --git a/airbyte-cdk/python/unit_tests/sources/scenario_based/__init__.py b/airbyte-cdk/python/unit_tests/sources/scenario_based/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py b/airbyte-cdk/python/unit_tests/sources/scenario_based/helpers.py similarity index 81% rename from airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py rename to airbyte-cdk/python/unit_tests/sources/scenario_based/helpers.py index 747d22a31a1f..0e19cec00c6e 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/scenario_based/helpers.py @@ -3,33 +3,34 @@ # import json +import logging import math from 
pathlib import Path, PosixPath from typing import Any, Dict, List, Mapping, Optional, Union import pytest -from _pytest.capture import CaptureFixture from _pytest.reports import ExceptionInfo -from airbyte_cdk.entrypoint import launch +from airbyte_cdk.entrypoint import get_source_iter from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput from airbyte_cdk.test.entrypoint_wrapper import read as entrypoint_read from airbyte_cdk.utils.traced_exception import AirbyteTracedException from airbyte_protocol.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalog -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario +from unit_tests.sources.scenario_based.scenario_builder import TestScenario -def verify_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def verify_discover(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: expected_exc, expected_msg = scenario.expected_discover_error expected_logs = scenario.expected_logs if expected_exc: with pytest.raises(expected_exc) as exc: - discover(capsys, tmp_path, scenario) + discover(tmp_path, scenario) if expected_msg: assert expected_msg in get_error_message_from_exc(exc) elif scenario.expected_catalog: - output = discover(capsys, tmp_path, scenario) + output = discover(tmp_path, scenario) catalog, logs = output["catalog"], output["logs"] assert catalog == scenario.expected_catalog if expected_logs: @@ -117,50 +118,47 @@ def _verify_expected_logs(logs: List[AirbyteLogMessage], expected_logs: Optional assert expected_message in actual_message -def verify_spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> None: - assert spec(capsys, scenario) == scenario.expected_spec +def verify_spec(scenario: TestScenario[AbstractSource]) -> None: + assert spec(scenario) == scenario.expected_spec -def verify_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def verify_check(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: expected_exc, expected_msg = scenario.expected_check_error if expected_exc: with pytest.raises(expected_exc): - output = check(capsys, tmp_path, scenario) + output = check(tmp_path, scenario) if expected_msg: # expected_msg is a string. what's the expected value field? 
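                # If expected_msg is a plain string (as the comment above suggests), the equivalent
                # membership check would presumably be `assert expected_msg in output["message"]`
                # rather than reading a `.value` attribute from it.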
assert expected_msg.value in output["message"] # type: ignore assert output["status"] == scenario.expected_check_status else: - output = check(capsys, tmp_path, scenario) + output = check(tmp_path, scenario) assert output["status"] == scenario.expected_check_status -def spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> Mapping[str, Any]: - launch( +def spec(scenario: TestScenario[AbstractSource]) -> Mapping[str, Any]: + output = get_source_iter( scenario.source, ["spec"], ) - captured = capsys.readouterr() - return json.loads(captured.out.splitlines()[0])["spec"] # type: ignore + return json.loads(next(output))["spec"] # type: ignore -def check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: - launch( +def check(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: + output = get_source_iter( scenario.source, ["check", "--config", make_file(tmp_path / "config.json", scenario.config)], ) - captured = capsys.readouterr() - return json.loads(captured.out.splitlines()[0])["connectionStatus"] # type: ignore + return json.loads(next(output))["connectionStatus"] # type: ignore -def discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: - launch( +def discover(tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: + output = [json.loads(o) for o in get_source_iter( scenario.source, ["discover", "--config", make_file(tmp_path / "config.json", scenario.config)], - ) - output = [json.loads(line) for line in capsys.readouterr().out.splitlines()] + )] [catalog] = [o["catalog"] for o in output if o.get("catalog")] # type: ignore return { "catalog": catalog, @@ -194,3 +192,8 @@ def get_error_message_from_exc(exc: ExceptionInfo[Any]) -> str: if isinstance(exc.value, AirbyteTracedException): return exc.value.message return str(exc.value.args[0]) + + +class NeverLogSliceLogger(SliceLogger): + def should_log_slice_message(self, logger: logging.Logger) -> bool: + return False diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py b/airbyte-cdk/python/unit_tests/sources/scenario_based/scenario_builder.py similarity index 100% rename from airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py rename to airbyte-cdk/python/unit_tests/sources/scenario_based/scenario_builder.py diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py index 72a0425bc098..6fd76b710387 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py @@ -3,7 +3,7 @@ # from airbyte_cdk.sources.streams.concurrent.cursor import CursorField from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ConcurrencyCompatibleStateType -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import StreamFacadeSourceBuilder from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream diff --git 
a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 716eb5508eaf..dc7724b6ba50 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -16,7 +16,7 @@ from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import EpochValueConcurrentStreamStateConverter from airbyte_protocol.models import ConfiguredAirbyteStream -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder +from unit_tests.sources.scenario_based.scenario_builder import SourceBuilder from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger _CURSOR_FIELD = "cursor_field" diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py index ae66d3a44374..d780036b2ccf 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import StreamFacadeSourceBuilder from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py index af2249873035..db6d9b5efb9a 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py @@ -5,10 +5,9 @@ from pathlib import PosixPath import pytest -from _pytest.capture import CaptureFixture from freezegun import freeze_time -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario -from unit_tests.sources.file_based.test_scenarios import verify_discover, verify_read +from unit_tests.sources.scenario_based.scenario_builder import TestScenario +from unit_tests.sources.scenario_based.helpers import verify_discover, verify_read from unit_tests.sources.streams.concurrent.scenarios.incremental_scenarios import ( test_incremental_stream_with_slice_boundaries_no_input_state, test_incremental_stream_with_slice_boundaries_with_concurrent_state, @@ -72,5 +71,5 @@ def test_concurrent_read(scenario: TestScenario) -> None: @pytest.mark.parametrize("scenario", scenarios, ids=[s.name for s in scenarios]) -def test_concurrent_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario) -> None: - verify_discover(capsys, tmp_path, scenario) +def test_concurrent_discover(tmp_path: 
PosixPath, scenario: TestScenario) -> None: + verify_discover(tmp_path, scenario) diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py index 2f4ab9b9fccb..8828b599acbb 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py @@ -6,7 +6,7 @@ from airbyte_cdk.sources.message import InMemoryMessageRepository from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.scenario_based.scenario_builder import TestScenarioBuilder from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( AlwaysAvailableAvailabilityStrategy, ConcurrentSourceBuilder, diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py index 943aea30dbba..db29bae41516 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py @@ -18,9 +18,9 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator from airbyte_cdk.sources.streams.concurrent.partitions.record import Record from airbyte_cdk.sources.streams.core import StreamData -from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_protocol.models import ConfiguredAirbyteStream -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder +from unit_tests.sources.scenario_based.helpers import NeverLogSliceLogger +from unit_tests.sources.scenario_based.scenario_builder import SourceBuilder class LegacyStream(Stream): @@ -134,8 +134,3 @@ def set_message_repository(self, message_repository: MessageRepository) -> "Conc class AlwaysAvailableAvailabilityStrategy(AbstractAvailabilityStrategy): def check_availability(self, logger: logging.Logger) -> StreamAvailability: return StreamAvailable() - - -class NeverLogSliceLogger(SliceLogger): - def should_log_slice_message(self, logger: logging.Logger) -> bool: - return False diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/test_availability_strategy.py b/airbyte-cdk/python/unit_tests/sources/streams/http/test_availability_strategy.py index b63af7973854..32e96cf16717 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/test_availability_strategy.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/test_availability_strategy.py @@ -87,7 +87,7 @@ def read_records(self, *args, **kvargs): class MockResponseWithJsonContents(requests.Response, mocker.MagicMock): def __init__(self, *args, **kvargs): mocker.MagicMock.__init__(self) - requests.Response.__init__(self, **kvargs) + requests.Response.__init__(self) self.json = mocker.MagicMock() class MockSource(AbstractSource): diff --git 
a/airbyte-cdk/python/unit_tests/utils/test_rate_limiting.py b/airbyte-cdk/python/unit_tests/utils/test_rate_limiting.py index d1ed294b930d..85ddf535cae6 100644 --- a/airbyte-cdk/python/unit_tests/utils/test_rate_limiting.py +++ b/airbyte-cdk/python/unit_tests/utils/test_rate_limiting.py @@ -2,6 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +import aiohttp import pytest from airbyte_cdk.sources.streams.http.rate_limiting import default_backoff_handler from requests import exceptions @@ -18,6 +19,10 @@ def helper_with_exceptions(exception_type): (1, 1, 0, exceptions.ReadTimeout), (2, 2, 1, exceptions.ConnectionError), (3, 3, 1, exceptions.ChunkedEncodingError), + (1, None, 1, aiohttp.ClientPayloadError), + (1, None, 1, aiohttp.ServerTimeoutError), + (2, 2, 1, aiohttp.ServerConnectionError), + (2, 2, 1, aiohttp.ServerDisconnectedError), ], ) def test_default_backoff_handler(max_tries: int, max_time: int, factor: int, exception_to_raise: Exception): diff --git a/airbyte-integrations/connectors/source-salesforce/main.py b/airbyte-integrations/connectors/source-salesforce/main.py index 5ec9f05e1042..aa92e9371817 100644 --- a/airbyte-integrations/connectors/source-salesforce/main.py +++ b/airbyte-integrations/connectors/source-salesforce/main.py @@ -10,16 +10,18 @@ from airbyte_cdk.entrypoint import AirbyteEntrypoint, launch from airbyte_cdk.models import AirbyteErrorTraceMessage, AirbyteMessage, AirbyteTraceMessage, TraceType, Type -from source_salesforce import SourceSalesforce +from source_salesforce import AsyncSourceSalesforce, SalesforceSourceDispatcher, SourceSalesforce def _get_source(args: List[str]): catalog_path = AirbyteEntrypoint.extract_catalog(args) config_path = AirbyteEntrypoint.extract_config(args) + catalog = AsyncSourceSalesforce.read_catalog(catalog_path) if catalog_path else None + config = AsyncSourceSalesforce.read_config(config_path) if config_path else None try: - return SourceSalesforce( - SourceSalesforce.read_catalog(catalog_path) if catalog_path else None, - SourceSalesforce.read_config(config_path) if config_path else None, + return SalesforceSourceDispatcher( + AsyncSourceSalesforce(catalog, config), + SourceSalesforce(catalog, config) ) except Exception as error: print( diff --git a/airbyte-integrations/connectors/source-salesforce/setup.py b/airbyte-integrations/connectors/source-salesforce/setup.py index 22e6250d4660..f1b4ddfaedd8 100644 --- a/airbyte-integrations/connectors/source-salesforce/setup.py +++ b/airbyte-integrations/connectors/source-salesforce/setup.py @@ -7,7 +7,7 @@ MAIN_REQUIREMENTS = ["airbyte-cdk~=0.55.2", "pandas"] -TEST_REQUIREMENTS = ["freezegun", "pytest~=6.1", "pytest-mock~=3.6", "requests-mock~=1.9.3", "pytest-timeout"] +TEST_REQUIREMENTS = ["aioresponses", "freezegun", "pytest~=6.1", "pytest-asyncio", "pytest-mock~=3.6", "requests-mock~=1.9.3", "pytest-timeout"] setup( name="source_salesforce", diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/__init__.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/__init__.py index 60d17cf1e8f0..9cb637717902 100644 --- a/airbyte-integrations/connectors/source-salesforce/source_salesforce/__init__.py +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) 2021 Airbyte, Inc., all rights reserved. 
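# A minimal sketch of how the entrypoint wiring in main.py above is expected to be exercised
# (the file paths are illustrative, not part of this repository):
#
#   args = ["read", "--config", "secrets/config.json", "--catalog", "configured_catalog.json"]
#   source = _get_source(args)   # SalesforceSourceDispatcher wrapping AsyncSourceSalesforce and SourceSalesforce
#   launch(source, args)         # prints the resulting AirbyteMessages to stdout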
 #
+from .async_salesforce.source import AsyncSourceSalesforce, SalesforceSourceDispatcher
 from .source import SourceSalesforce

-__all__ = ["SourceSalesforce"]
+__all__ = ["AsyncSourceSalesforce", "SalesforceSourceDispatcher", "SourceSalesforce"]
diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/__init__.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/__init__.py
new file mode 100644
index 000000000000..69de4c37ac7d
--- /dev/null
+++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/__init__.py
@@ -0,0 +1,3 @@
+from .source import AsyncSourceSalesforce, SalesforceSourceDispatcher
+
+__all__ = ["AsyncSourceSalesforce", "SalesforceSourceDispatcher"]
diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/availability_strategy.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/availability_strategy.py
new file mode 100644
index 000000000000..cbb9eac0a16f
--- /dev/null
+++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/availability_strategy.py
@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import logging
+import typing
+from typing import Optional, Tuple
+
+from airbyte_cdk.sources.async_cdk.streams.http.availability_strategy_async import AsyncHttpAvailabilityStrategy
+from airbyte_cdk.sources.streams import Stream
+from airbyte_cdk.sources.streams.http.utils import HttpError
+from requests import codes
+
+if typing.TYPE_CHECKING:
+    from airbyte_cdk.sources import Source
+
+
+class AsyncSalesforceAvailabilityStrategy(AsyncHttpAvailabilityStrategy):
+    def handle_http_error(
+        self, stream: Stream, logger: logging.Logger, source: Optional["Source"], error: HttpError
+    ) -> Tuple[bool, Optional[str]]:
+        """
+        There are several types of Salesforce sobjects that require additional processing:
+          1. Sobjects the user restricted access to after setting up the connection in Airbyte;
+             requests for these return 403 HTTP errors.
+          2. Streams that cannot be sampled with Salesforce `query` or `queryAll`.
+        Because streams are generated dynamically for the Salesforce connector, these cannot be
+        filtered out at discover time, so we check for them before reading from the streams.
+        """
+        if error.status_code in [codes.FORBIDDEN, codes.BAD_REQUEST]:
+            error_data = error.json()[0]
+            error_code = error_data.get("errorCode", "")
+            if error_code != "REQUEST_LIMIT_EXCEEDED" or error_code == "INVALID_TYPE_FOR_OPERATION":
+                return False, f"Cannot receive data for stream '{stream.name}', error message: '{error_data.get('message')}'"
+            return True, None
+        raise error
diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/rate_limiting.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/rate_limiting.py
new file mode 100644
index 000000000000..2624230055a9
--- /dev/null
+++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/rate_limiting.py
@@ -0,0 +1,52 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
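# Summary of the availability strategy above: handle_http_error marks a stream unavailable
# (returns False plus a human-readable reason) for 400/403 errors caused by restricted or
# non-queryable sobjects, returns (True, None) when a 403 reports errorCode REQUEST_LIMIT_EXCEEDED
# so the stream is not skipped outright, and re-raises every other error.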
+# + +import logging +import sys + +import aiohttp +import backoff +from airbyte_cdk.sources.async_cdk.streams.http.exceptions_async import AsyncDefaultBackoffException +from requests import codes, exceptions # type: ignore[import] + +TRANSIENT_EXCEPTIONS = ( + AsyncDefaultBackoffException, + aiohttp.ClientPayloadError, + aiohttp.ServerTimeoutError, + aiohttp.ServerConnectionError, + aiohttp.ServerDisconnectedError, +) + +logger = logging.getLogger("airbyte") + + +def default_backoff_handler(max_tries: int, factor: int, **kwargs): + def log_retry_attempt(details): + _, exc, _ = sys.exc_info() + logger.info(str(exc)) + logger.info(f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying...") + + def should_give_up(exc): + give_up = exc.response is not None and exc.response.status_code != codes.too_many_requests and 400 <= exc.status_code < 500 + + # Salesforce can return an error with a limit using a 403 code error. + if exc.status_code == codes.forbidden: + error_data = exc.json()[0] + if error_data.get("errorCode", "") == "REQUEST_LIMIT_EXCEEDED": + give_up = True + + if give_up: + logger.info(f"Giving up for returned HTTP status: {exc.status_code}, body: {exc.text}") + return give_up + + return backoff.on_exception( + backoff.expo, + TRANSIENT_EXCEPTIONS, + jitter=None, + on_backoff=log_retry_attempt, + giveup=should_give_up, + max_tries=max_tries, + factor=factor, + **kwargs, + ) diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/source.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/source.py new file mode 100644 index 000000000000..4d22df1d4fd9 --- /dev/null +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/source.py @@ -0,0 +1,238 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
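# A minimal usage sketch for the decorator defined in rate_limiting.py above (the decorated function
# is illustrative; streams.py further below applies it as @default_backoff_handler(max_tries=5, factor=15)):
#
#   @default_backoff_handler(max_tries=5, factor=15)
#   async def fetch_json(session: aiohttp.ClientSession, url: str) -> dict:
#       async with session.get(url) as response:
#           return await response.json()
#
# Exceptions listed in TRANSIENT_EXCEPTIONS are retried with exponential backoff, while should_give_up
# is meant to abort retries for 4xx responses other than 429 (including a 403 whose errorCode is
# REQUEST_LIMIT_EXCEEDED).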
+# + +import logging +from datetime import datetime +from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union + +import requests +from airbyte_cdk import AirbyteLogger +from airbyte_cdk.logger import AirbyteLogFormatter +from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, Level, SyncMode +from airbyte_cdk.sources.async_cdk.abstract_source_async import AsyncAbstractSource +from airbyte_cdk.sources.async_cdk.source_dispatcher import SourceDispatcher +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.message import InMemoryMessageRepository +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig +from airbyte_cdk.utils.traced_exception import AirbyteTracedException +from dateutil.relativedelta import relativedelta +from requests import codes, exceptions # type: ignore[import] + +from source_salesforce.api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce +from .streams import ( + BulkIncrementalSalesforceStream, + BulkSalesforceStream, + BulkSalesforceSubStream, + Describe, + IncrementalRestSalesforceStream, + RestSalesforceStream, + RestSalesforceSubStream, +) + +_DEFAULT_CONCURRENCY = 10 +_MAX_CONCURRENCY = 10 +logger = logging.getLogger("airbyte") + + +class AirbyteStopSync(AirbyteTracedException): + pass + + +class AsyncSourceSalesforce(AsyncAbstractSource): + DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" + START_DATE_OFFSET_IN_YEARS = 2 + MAX_WORKERS = 5 + + message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[logger.level])) + + def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional[Mapping[str, Any]], **kwargs): + if config: + concurrency_level = min(config.get("num_workers", _DEFAULT_CONCURRENCY), _MAX_CONCURRENCY) + else: + concurrency_level = _DEFAULT_CONCURRENCY + logger.info(f"Using concurrent cdk with concurrency level {concurrency_level}") + self.catalog = catalog + super().__init__() + + @staticmethod + async def _get_sf_object(config: Mapping[str, Any]) -> Salesforce: + sf = Salesforce(**config) + sf.login() + return sf + + async def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Optional[str]]: + try: + salesforce = await self._get_sf_object(config) + salesforce.describe() + except exceptions.HTTPError as error: + error_msg = f"An error occurred: {error.response.text}" + try: + error_data = error.response.json()[0] + except (KeyError, requests.exceptions.JSONDecodeError): + pass + else: + error_code = error_data.get("errorCode") + if error.response.status_code == codes.FORBIDDEN and error_code == "REQUEST_LIMIT_EXCEEDED": + logger.warn(f"API Call limit is exceeded. 
Error message: '{error_data.get('message')}'") + error_msg = "API Call limit is exceeded" + return False, error_msg + return True, None + + @classmethod + def _get_api_type(cls, stream_name: str, json_schema: Mapping[str, Any], force_use_bulk_api: bool) -> str: + """Get proper API type: rest or bulk""" + # Salesforce BULK API currently does not support loading fields with data type base64 and compound data + properties = json_schema.get("properties", {}) + properties_not_supported_by_bulk = { + key: value for key, value in properties.items() if value.get("format") == "base64" or "object" in value["type"] + } + rest_only = stream_name in UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS + if rest_only: + logger.warning(f"BULK API is not supported for stream: {stream_name}") + return "rest" + if force_use_bulk_api and properties_not_supported_by_bulk: + logger.warning( + f"Following properties will be excluded from stream: {stream_name} due to BULK API limitations: {list(properties_not_supported_by_bulk)}" + ) + return "bulk" + if properties_not_supported_by_bulk: + return "rest" + return "bulk" + + @classmethod + def _get_stream_type(cls, stream_name: str, api_type: str): + """Get proper stream class: full_refresh, incremental or substream + + SubStreams (like ContentDocumentLink) do not support incremental sync because of query restrictions, look here: + https://developer.salesforce.com/docs/atlas.en-us.object_reference.meta/object_reference/sforce_api_objects_contentdocumentlink.htm + """ + parent_name = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("parent_name") + if api_type == "rest": + full_refresh = RestSalesforceSubStream if parent_name else RestSalesforceStream + incremental = IncrementalRestSalesforceStream + elif api_type == "bulk": + full_refresh = BulkSalesforceSubStream if parent_name else BulkSalesforceStream + incremental = BulkIncrementalSalesforceStream + else: + raise Exception(f"Stream {stream_name} cannot be processed by REST or BULK API.") + return full_refresh, incremental + + @classmethod + def prepare_stream(cls, stream_name: str, json_schema, sobject_options, sf_object, authenticator, config): + """Choose proper stream class: syncMode(full_refresh/incremental), API type(Rest/Bulk), SubStream""" + pk, replication_key = sf_object.get_pk_and_replication_key(json_schema) + stream_kwargs = { + "stream_name": stream_name, + "schema": json_schema, + "pk": pk, + "sobject_options": sobject_options, + "sf_api": sf_object, + "authenticator": authenticator, + "start_date": config.get("start_date"), + } + + api_type = cls._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False)) + full_refresh, incremental = cls._get_stream_type(stream_name, api_type) + if replication_key and stream_name not in UNSUPPORTED_FILTERING_STREAMS: + stream_class = incremental + stream_kwargs["replication_key"] = replication_key + else: + stream_class = full_refresh + + return stream_class, stream_kwargs + + @classmethod + async def generate_streams( + cls, + config: Mapping[str, Any], + stream_objects: Mapping[str, Any], + sf_object: Salesforce, + ) -> List[Stream]: + """Generates a list of stream by their names. 
It can be used for different tests too""" + authenticator = TokenAuthenticator(sf_object.access_token) + schemas = sf_object.generate_schemas(stream_objects) + default_args = [sf_object, authenticator, config] + streams = [] + for stream_name, sobject_options in stream_objects.items(): + json_schema = schemas.get(stream_name, {}) + + stream_class, kwargs = cls.prepare_stream(stream_name, json_schema, sobject_options, *default_args) + + parent_name = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("parent_name") + if parent_name: + # get minimal schema required for getting proper class name full_refresh/incremental, rest/bulk + parent_schema = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("schema_minimal") + parent_class, parent_kwargs = cls.prepare_stream(parent_name, parent_schema, sobject_options, *default_args) + kwargs["parent"] = parent_class(**parent_kwargs) + + stream = stream_class(**kwargs) + + api_type = cls._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False)) + if api_type == "rest" and not stream.primary_key and stream.too_many_properties: + logger.warning( + f"Can not instantiate stream {stream_name}. It is not supported by the BULK API and can not be " + "implemented via REST because the number of its properties exceeds the limit and it lacks a primary key." + ) + continue + streams.append(stream) + return streams + + async def streams(self, config: Mapping[str, Any]) -> List[Stream]: + if not config.get("start_date"): + config["start_date"] = (datetime.now() - relativedelta(years=self.START_DATE_OFFSET_IN_YEARS)).strftime(self.DATETIME_FORMAT) + sf = await self._get_sf_object(config) + stream_objects = sf.get_validated_streams(config=config, catalog=self.catalog) + streams = await self.generate_streams(config, stream_objects, sf) + streams.append(Describe(sf_api=sf, catalog=self.catalog)) + # TODO: incorporate state & ConcurrentCursor when we support incremental + configured_streams = [] + for stream in streams: + configured_streams.append(stream) + return configured_streams + + def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]: + if self.catalog: + for catalog_stream in self.catalog.streams: + if stream.name == catalog_stream.stream.name: + return catalog_stream.sync_mode + return None + + async def read_stream( + self, + logger: logging.Logger, + stream_instance: Stream, + configured_stream: ConfiguredAirbyteStream, + state_manager: ConnectorStateManager, + internal_config: InternalConfig, + ) -> Iterator[AirbyteMessage]: + try: + async for record in super().read_stream(logger, stream_instance, configured_stream, state_manager, internal_config): + yield record + except HttpError as error: + error_data = error.json() + error_code = error_data.get("errorCode") + url = error.url + if error.status_code == codes.FORBIDDEN and error_code == "REQUEST_LIMIT_EXCEEDED": + logger.warning(f"API Call {url} limit is exceeded. Error message: '{error_data.get('message')}'") + raise AirbyteStopSync() # if got 403 rate limit response, finish the sync with success. 
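        # AirbyteStopSync raised above is caught by SalesforceSourceDispatcher.read below, which logs
        # "Finished syncing ..." and lets the sync end successfully instead of reporting the
        # rate-limited stream as a failure.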
+ raise error + + +class SalesforceSourceDispatcher(SourceDispatcher): + def read( + self, + logger: logging.Logger, + config: Mapping[str, Any], + catalog: ConfiguredAirbyteCatalog, + state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None, + ) -> Iterator[AirbyteMessage]: + # save for use inside streams method + self.catalog = catalog + try: + yield from super().read(logger, config, catalog, state) + except AirbyteStopSync: + logger.info(f"Finished syncing {self.async_source.name}") diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/streams.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/streams.py new file mode 100644 index 000000000000..d9d30bfcc1c0 --- /dev/null +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/async_salesforce/streams.py @@ -0,0 +1,817 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import csv +import ctypes +import math +import os +import time +import urllib.parse +import uuid +from abc import ABC +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union + +import aiohttp +import pandas as pd +import pendulum +import requests # type: ignore[import] +from airbyte_cdk.models import ConfiguredAirbyteCatalog, FailureType, SyncMode +from airbyte_cdk.sources.async_cdk.streams.http.http_async import AsyncHttpStream, AsyncHttpSubStream +from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy +from airbyte_cdk.sources.streams.core import Stream, StreamData +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer +from airbyte_cdk.utils import AirbyteTracedException +from numpy import nan +from pendulum import DateTime # type: ignore[attr-defined] +from requests import codes +from requests.models import PreparedRequest + +from source_salesforce.api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce +from source_salesforce.async_salesforce.availability_strategy import AsyncSalesforceAvailabilityStrategy +from source_salesforce.async_salesforce.rate_limiting import default_backoff_handler +from source_salesforce.exceptions import SalesforceException, TmpFileIOError + +# https://stackoverflow.com/a/54517228 +CSV_FIELD_SIZE_LIMIT = int(ctypes.c_ulong(-1).value // 2) +csv.field_size_limit(CSV_FIELD_SIZE_LIMIT) + +DEFAULT_ENCODING = "utf-8" + + +class SalesforceStream(AsyncHttpStream, ABC): + page_size = 2000 + transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) + encoding = DEFAULT_ENCODING + + def __init__( + self, + sf_api: Salesforce, + pk: str, + stream_name: str, + sobject_options: Mapping[str, Any] = None, + schema: dict = None, + start_date=None, + **kwargs, + ): + super().__init__(**kwargs) + self.sf_api = sf_api + self.pk = pk + self.stream_name = stream_name + self.schema: Mapping[str, Any] = schema # type: ignore[assignment] + self.sobject_options = sobject_options + self.start_date = self.format_start_date(start_date) + + @staticmethod + def format_start_date(start_date: Optional[str]) -> Optional[str]: + """Transform the format `2021-07-25` into the format `2021-07-25T00:00:00Z`""" + if start_date: + return pendulum.parse(start_date).strftime("%Y-%m-%dT%H:%M:%SZ") # type: ignore[attr-defined,no-any-return] + return None + + @property + def max_properties_length(self) -> int: + return Salesforce.REQUEST_SIZE_LIMITS 
- len(self.url_base) - 2000 + + @property + def name(self) -> str: + return self.stream_name + + @property + def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + return self.pk + + @property + def url_base(self) -> str: + return self.sf_api.instance_url + + @property + def availability_strategy(self) -> Optional["AvailabilityStrategy"]: + return AsyncSalesforceAvailabilityStrategy() + + @property + def too_many_properties(self): + selected_properties = self.get_json_schema().get("properties", {}) + properties_length = len(urllib.parse.quote(",".join(p for p in selected_properties))) + return properties_length > self.max_properties_length + + async def parse_response(self, response: aiohttp.ClientResponse, **kwargs) -> List[Mapping]: + for record in (await response.json())["records"]: + yield record + + def get_json_schema(self) -> Mapping[str, Any]: + if not self.schema: + self.schema = self.sf_api.generate_schema(self.name) + return self.schema + + def get_error_display_message(self, exception: BaseException) -> Optional[str]: + if isinstance(exception, HttpError): + return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later." + return super().get_error_display_message(exception) + + +class PropertyChunk: + """ + Object that is used to keep track of the current state of a chunk of properties for the stream of records being synced. + """ + + properties: Mapping[str, Any] + first_time: bool + record_counter: int + next_page: Optional[Mapping[str, Any]] + + def __init__(self, properties: Mapping[str, Any]): + self.properties = properties + self.first_time = True + self.record_counter = 0 + self.next_page = None + + +class RestSalesforceStream(SalesforceStream): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.primary_key or not self.too_many_properties + + def path(self, next_page_token: Mapping[str, Any] = None, **kwargs: Any) -> str: + if next_page_token: + """ + If `next_page_token` is set, subsequent requests use `nextRecordsUrl`. + """ + next_token: str = next_page_token["next_token"] + return next_token + return f"/services/data/{self.sf_api.version}/queryAll" + + async def next_page_token(self, response: aiohttp.ClientResponse) -> Optional[Mapping[str, Any]]: + response_data = await response.json() + next_token = response_data.get("nextRecordsUrl") + return {"next_token": next_token} if next_token else None + + def request_params( + self, + stream_state: Mapping[str, Any], + stream_slice: Mapping[str, Any] = None, + next_page_token: Mapping[str, Any] = None, + property_chunk: Mapping[str, Any] = None, + ) -> MutableMapping[str, Any]: + """ + Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm + """ + if next_page_token: + # If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters. 
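        # For a hypothetical stream Account with primary key Id and selected properties Id and Name,
        # the first request built below carries roughly q=SELECT Id,Name FROM Account ORDER BY Id ASC,
        # while follow-up pages are fetched through nextRecordsUrl and need no query parameters.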
+ return {} + + property_chunk = property_chunk or {} + query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} " + + if self.name in PARENT_SALESFORCE_OBJECTS: + # add where clause: " WHERE ContentDocumentId IN ('06905000000NMXXXXX', ...)" + parent_field = PARENT_SALESFORCE_OBJECTS[self.name]["field"] + parent_ids = [f"'{parent_record[parent_field]}'" for parent_record in stream_slice["parents"]] + query += f" WHERE ContentDocumentId IN ({','.join(parent_ids)})" + + if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS: + query += f"ORDER BY {self.primary_key} ASC" + + return {"q": query} + + def chunk_properties(self) -> Iterable[Mapping[str, Any]]: + selected_properties = self.get_json_schema().get("properties", {}) + + def empty_props_with_pk_if_present(): + return {self.primary_key: selected_properties[self.primary_key]} if self.primary_key else {} + + summary_length = 0 + local_properties = empty_props_with_pk_if_present() + for property_name, value in selected_properties.items(): + current_property_length = len(urllib.parse.quote(f"{property_name},")) + if current_property_length + summary_length >= self.max_properties_length: + yield local_properties + local_properties = empty_props_with_pk_if_present() + summary_length = 0 + + local_properties[property_name] = value + summary_length += current_property_length + + if local_properties: + yield local_properties + + @staticmethod + def _next_chunk_id(property_chunks: Mapping[int, PropertyChunk]) -> Optional[int]: + """ + Figure out which chunk is going to be read next. + It should be the one with the least number of records read by the moment. + """ + non_exhausted_chunks = { + # We skip chunks that have already attempted a sync before and do not have a next page + chunk_id: property_chunk.record_counter + for chunk_id, property_chunk in property_chunks.items() + if property_chunk.first_time or property_chunk.next_page + } + if not non_exhausted_chunks: + return None + return min(non_exhausted_chunks, key=non_exhausted_chunks.get) + + async def _read_pages( + self, + records_generator_fn: Callable[ + [aiohttp.ClientRequest, aiohttp.ClientResponse, Mapping[str, Any], Mapping[str, Any]], Iterable[StreamData] + ], + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> None: + async for record in self._do_read_pages(records_generator_fn, stream_slice, stream_state): + yield record + + async def _do_read_pages( + self, + records_generator_fn: Callable[ + [aiohttp.ClientRequest, aiohttp.ClientResponse, Mapping[str, Any], Mapping[str, Any]], Iterable[StreamData] + ], + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ): + stream_state = stream_state or {} + records_by_primary_key = {} + property_chunks: Mapping[int, PropertyChunk] = { + index: PropertyChunk(properties=properties) for index, properties in enumerate(self.chunk_properties()) + } + while True: + chunk_id = self._next_chunk_id(property_chunks) + if chunk_id is None: + # pagination complete + break + + property_chunk = property_chunks[chunk_id] + + async def f(): + request, response = await self._fetch_next_page_for_chunk( + stream_slice, stream_state, property_chunk.next_page, property_chunk.properties + ) + next_page = await self.next_page_token(response) + return request, response, next_page + + request, response, property_chunk.next_page = await f() + + # When this is the first time we're getting a chunk's records, we set this to False to be used when deciding the next chunk + 
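            # Once first_time has been cleared and the chunk has no next_page, _next_chunk_id above
            # treats that chunk as exhausted and stops scheduling it.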
if property_chunk.first_time: + property_chunk.first_time = False + if not self.too_many_properties: + # this is the case when a stream has no primary key + # (it is allowed when properties length does not exceed the maximum value) + # so there would be a single chunk, therefore we may and should yield records immediately + async for record in records_generator_fn(request, response, stream_state, stream_slice): + property_chunk.record_counter += 1 + yield record + continue + + # stick together different parts of records by their primary key and emit if a record is complete + async for record in records_generator_fn(request, response, stream_state, stream_slice): + property_chunk.record_counter += 1 + record_id = record[self.primary_key] + if record_id not in records_by_primary_key: + records_by_primary_key[record_id] = (record, 1) + continue + partial_record, counter = records_by_primary_key[record_id] + partial_record.update(record) + counter += 1 + if counter == len(property_chunks): + yield partial_record # now it's complete + records_by_primary_key.pop(record_id) + else: + records_by_primary_key[record_id] = (partial_record, counter) + + # Process what's left. + # Because we make multiple calls to query N records (each call to fetch X properties of all the N records), + # there's a chance that the number of records corresponding to the query may change between the calls. + # Select 'a', 'b' from table order by pk -> returns records with ids `1`, `2` + # + # Select 'c', 'd' from table order by pk -> returns records with ids `1`, `3` + # Then records `2` and `3` would be incomplete. + # This may result in data inconsistency. We skip such records for now and log a warning message. + incomplete_record_ids = ",".join([str(key) for key in records_by_primary_key]) + if incomplete_record_ids: + self.logger.warning(f"Inconsistent record(s) with primary keys {incomplete_record_ids} found. Skipping them.") + + async def _fetch_next_page_for_chunk( + self, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + next_page_token: Mapping[str, Any] = None, + property_chunk: Mapping[str, Any] = None, + ) -> Tuple[aiohttp.ClientRequest, aiohttp.ClientResponse]: + request_headers = self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + request = self._create_prepared_request( + path=self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + headers=dict(request_headers, **self.authenticator.get_auth_header()), + params=self.request_params( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, property_chunk=property_chunk + ), + json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + ) + request_kwargs = self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + response = await self._send_request(request, request_kwargs) + return request, response + + +class BatchedSubStream(AsyncHttpSubStream): + SLICE_BATCH_SIZE = 200 + + async def stream_slices( + self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + ) -> Iterable[Optional[Mapping[str, Any]]]: + """Instead of yielding one parent record at a time, make stream slice contain a batch of parent records. 
+ + It allows to get records by one requests (instead of only one). + """ + await self.ensure_session() # TODO: should this be self or super? + batched_slice = [] + async for stream_slice in super().stream_slices(sync_mode, cursor_field, stream_state): + if len(batched_slice) == self.SLICE_BATCH_SIZE: + yield {"parents": batched_slice} + batched_slice = [] + batched_slice.append(stream_slice["parent"]) + if batched_slice: + yield {"parents": batched_slice} + + +class RestSalesforceSubStream(BatchedSubStream, RestSalesforceStream): + pass + + +class BulkSalesforceStream(SalesforceStream): + DEFAULT_WAIT_TIMEOUT_SECONDS = 86400 # 24-hour bulk job running time + MAX_CHECK_INTERVAL_SECONDS = 2.0 + MAX_RETRY_NUMBER = 3 + + def path(self, next_page_token: Mapping[str, Any] = None, **kwargs: Any) -> str: + return f"/services/data/{self.sf_api.version}/jobs/query" + + transformer = TypeTransformer(TransformConfig.CustomSchemaNormalization | TransformConfig.DefaultSchemaNormalization) + + @default_backoff_handler(max_tries=5, factor=15) + async def _send_http_request(self, method: str, url: str, json: dict = None, headers: dict = None, stream=False) -> aiohttp.ClientResponse: # TODO: how to handle the stream argument + headers = self.authenticator.get_auth_header() if not headers else headers | self.authenticator.get_auth_header() + response = await self._session.request(method, url=url, headers=headers, json=json) + if response.status not in [200, 204]: + self.logger.error(f"error body: {await response.text()}, sobject options: {self.sobject_options}") + return await self.handle_response_with_error(response) + + async def create_stream_job(self, query: str, url: str) -> Optional[str]: + """ + docs: https://developer.salesforce.com/docs/atlas.en-us.api_asynch.meta/api_asynch/create_job.html + """ + json = {"operation": "queryAll", "query": query, "contentType": "CSV", "columnDelimiter": "COMMA", "lineEnding": "LF"} + try: + response = await self._send_http_request("POST", url, json=json) + job_id: str = (await response.json())["id"] + return job_id + except HttpError as error: # TODO: which errors? + if error.status_code in [codes.FORBIDDEN, codes.BAD_REQUEST]: + # A part of streams can't be used by BULK API. Every API version can have a custom list of + # these sobjects. Another part of them can be generated dynamically. That's why we can't track + # them preliminarily and there is only one way is to except error with necessary messages about + # their limitations. Now we know about 3 different reasons of similar errors: + # 1) some SaleForce sobjects(streams) is not supported by the BULK API simply (as is). + # 2) Access to a sobject(stream) is not available + # 3) sobject is not queryable. It means this sobject can't be called directly. + # We can call it as part of response from another sobject only. E.g.: + # initial query: "Select Id, Subject from ActivityHistory" -> error + # updated query: "Select Name, (Select Subject,ActivityType from ActivityHistories) from Contact" + # The second variant forces customisation for every case (ActivityHistory, ActivityHistories etc). + # And the main problem is these subqueries doesn't support CSV response format. 
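            # For the cases that are only logged below, no job is created (create_stream_job returns None)
            # and read_records later falls back to the standard non-BULK REST implementation instead of
            # failing the whole sync.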
+ error_data = error.json() or {} + error_code = error_data.get("errorCode") + error_message = error_data.get("message", "") + if error_message == "Selecting compound data not supported in Bulk Query" or ( + error_code == "INVALIDENTITY" and "is not supported by the Bulk API" in error_message + ): + self.logger.error( + f"Cannot receive data for stream '{self.name}' using BULK API, " + f"sobject options: {self.sobject_options}, error message: '{error_message}'" + ) + elif error.status_code == codes.FORBIDDEN and error_code != "REQUEST_LIMIT_EXCEEDED": + self.logger.error( + f"Cannot receive data for stream '{self.name}' ," + f"sobject options: {self.sobject_options}, error message: '{error_message}'" + ) + elif error.status_code == codes.FORBIDDEN and error_code == "REQUEST_LIMIT_EXCEEDED": + self.logger.error( + f"Cannot receive data for stream '{self.name}' ," + f"sobject options: {self.sobject_options}, Error message: '{error_data.get('message')}'" + ) + elif error.status_code == codes.BAD_REQUEST and error_message.endswith("does not support query"): + self.logger.error( + f"The stream '{self.name}' is not queryable, " + f"sobject options: {self.sobject_options}, error message: '{error_message}'" + ) + elif ( + error.status_code == codes.BAD_REQUEST + and error_code == "API_ERROR" + and error_message.startswith("Implementation restriction") + ): + message = f"Unable to sync '{self.name}'. To prevent future syncs from failing, ensure the authenticated user has \"View all Data\" permissions." + raise AirbyteTracedException(message=message, failure_type=FailureType.config_error, exception=error) + elif error.status_code == codes.BAD_REQUEST and error_code == "LIMIT_EXCEEDED": + message = "Your API key for Salesforce has reached its limit for the 24-hour period. We will resume replication once the limit has elapsed." + self.logger.error(message) + else: + raise error + else: + raise error + return None + + async def wait_for_job(self, url: str) -> str: + expiration_time: DateTime = pendulum.now().add(seconds=self.DEFAULT_WAIT_TIMEOUT_SECONDS) + job_status = "InProgress" + delay_timeout = 0.0 + delay_cnt = 0 + job_info = None + # minimal starting delay is 0.5 seconds. + # this value was received empirically + time.sleep(0.5) + while pendulum.now() < expiration_time: + try: + job_info = await (await self._send_http_request("GET", url=url)).json() + except HttpError as error: + error_data = error.json() + error_code = error_data.get("errorCode") + error_message = error_data.get("message", "") + if ( + "We can't complete the action because enabled transaction security policies took too long to complete." in error_message + and error_code == "TXN_SECURITY_METERING_ERROR" + ): + message = 'A transient authentication error occurred. To prevent future syncs from failing, assign the "Exempt from Transaction Security" user permission to the authenticated user.' 
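                    # Raising a config_error below surfaces this as a configuration problem the user can fix
                    # (grant the "Exempt from Transaction Security" permission) rather than as an opaque sync failure.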
+ raise AirbyteTracedException(message=message, failure_type=FailureType.config_error, exception=error) + else: + raise error + job_status = job_info["state"] + if job_status in ["JobComplete", "Aborted", "Failed"]: + if job_status != "JobComplete": + # this is only job metadata without payload + error_message = job_info.get("errorMessage") + if not error_message: + # not all failed response can have "errorMessage" and we need to show full response body + error_message = job_info + self.logger.error(f"JobStatus: {job_status}, sobject options: {self.sobject_options}, error message: '{error_message}'") + + return job_status + + if delay_timeout < self.MAX_CHECK_INTERVAL_SECONDS: + delay_timeout = 0.5 + math.exp(delay_cnt) / 1000.0 + delay_cnt += 1 + + time.sleep(delay_timeout) + job_id = job_info["id"] + self.logger.info( + f"Sleeping {delay_timeout} seconds while waiting for Job: {self.name}/{job_id} to complete. Current state: {job_status}" + ) + + self.logger.warning(f"Not wait the {self.name} data for {self.DEFAULT_WAIT_TIMEOUT_SECONDS} seconds, data: {job_info}!!") + return job_status + + async def execute_job(self, query: str, url: str) -> Tuple[Optional[str], Optional[str]]: + job_status = "Failed" + for i in range(0, self.MAX_RETRY_NUMBER): + job_id = await self.create_stream_job(query=query, url=url) + if not job_id: + return None, job_status + job_full_url = f"{url}/{job_id}" + job_status = await self.wait_for_job(url=job_full_url) + if job_status not in ["UploadComplete", "InProgress"]: + break + self.logger.error(f"Waiting error. Try to run this job again {i + 1}/{self.MAX_RETRY_NUMBER}...") + self.abort_job(url=job_full_url) + job_status = "Aborted" + + if job_status in ["Aborted", "Failed"]: + await self.delete_job(url=job_full_url) + return None, job_status + return job_full_url, job_status + + def filter_null_bytes(self, b: bytes): + """ + https://github.com/airbytehq/airbyte/issues/8300 + """ + res = b.replace(b"\x00", b"") + if len(res) < len(b): + self.logger.warning("Filter 'null' bytes from string, size reduced %d -> %d chars", len(b), len(res)) + return res + + def get_response_encoding(self, headers) -> str: + """Returns encodings from given HTTP Header Dict. + + :param headers: dictionary to extract encoding from. + :rtype: str + """ + + content_type = headers.get("content-type") + + if not content_type: + return self.encoding + + content_type, params = requests.utils._parse_content_type_header(content_type) + + if "charset" in params: + return params["charset"].strip("'\"") + + return self.encoding + + async def download_data(self, url: str, chunk_size: int = 1024) -> tuple[str, str, dict]: + """ + Retrieves binary data result from successfully `executed_job`, using chunks, to avoid local memory limitations. + @ url: string - the url of the `executed_job` + @ chunk_size: int - the buffer size for each chunk to fetch from stream, in bytes, default: 1024 bytes + Return the tuple containing string with file path of downloaded binary data (Saved temporarily) and file encoding. 
+ """ + # set filepath for binary data from response + tmp_file = str(uuid.uuid4()) + response = await self._send_http_request("GET", url, headers={"Accept-Encoding": "gzip"}, stream=True) + with open( + tmp_file, "wb" + ) as data_file: + response_headers = response.headers + response_encoding = self.get_response_encoding(response_headers) + async for chunk in response.content.iter_chunked(chunk_size): + data_file.write(self.filter_null_bytes(chunk)) + # check the file exists + if os.path.isfile(tmp_file): + return tmp_file, response_encoding, response_headers + else: + raise TmpFileIOError(f"The IO/Error occured while verifying binary data. Stream: {self.name}, file {tmp_file} doesn't exist.") + + def read_with_chunks(self, path: str, file_encoding: str, chunk_size: int = 100) -> Iterable[Tuple[int, Mapping[str, Any]]]: + """ + Reads the downloaded binary data, using lines chunks, set by `chunk_size`. + @ path: string - the path to the downloaded temporarily binary data. + @ file_encoding: string - encoding for binary data file according to Standard Encodings from codecs module + @ chunk_size: int - the number of lines to read at a time, default: 100 lines / time. + """ + try: + with open(path, "r", encoding=file_encoding) as data: + chunks = pd.read_csv(data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object) + for chunk in chunks: + chunk = chunk.replace({nan: None}).to_dict(orient="records") + for row in chunk: + yield row + except pd.errors.EmptyDataError as e: + self.logger.info(f"Empty data received. {e}") + yield from [] + except IOError as ioe: + raise TmpFileIOError(f"The IO/Error occured while reading tmp data. Called: {path}. Stream: {self.name}", ioe) + finally: + # remove binary tmp file, after data is read + os.remove(path) + + def abort_job(self, url: str): + data = {"state": "Aborted"} + self._send_http_request("PATCH", url=url, json=data) + self.logger.warning("Broken job was aborted") + + async def delete_job(self, url: str): + await self._send_http_request("DELETE", url=url) + + @property + def availability_strategy(self) -> Optional["AvailabilityStrategy"]: + return None + + def next_page_token(self, last_record: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: + return None + + def get_query_select_fields(self) -> str: + return ", ".join( + { + key: value + for key, value in self.get_json_schema().get("properties", {}).items() + if value.get("format") != "base64" and "object" not in value["type"] + } + ) + + def request_params( + self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None + ) -> MutableMapping[str, Any]: + """ + Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm + """ + + select_fields = self.get_query_select_fields() + query = f"SELECT {select_fields} FROM {self.name}" + if next_page_token: + query += next_page_token["next_token"] + + if self.name in PARENT_SALESFORCE_OBJECTS: + # add where clause: " WHERE ContentDocumentId IN ('06905000000NMXXXXX', '06905000000Mxp7XXX', ...)" + parent_field = PARENT_SALESFORCE_OBJECTS[self.name]["field"] + parent_ids = [f"'{parent_record[parent_field]}'" for parent_record in stream_slice["parents"]] + query += f" WHERE ContentDocumentId IN ({','.join(parent_ids)})" + + return {"q": query} + + async def read_records( + self, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_slice: Mapping[str, Any] = None, + stream_state: Mapping[str, Any] = None, + ) -> 
Iterable[Mapping[str, Any]]:
+        async for record in self._do_read_records(sync_mode, cursor_field, stream_slice, stream_state):
+            yield record
+
+    async def _do_read_records(
+        self,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ):
+        stream_state = stream_state or {}
+        next_page_token = None
+
+        params = self.request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
+        path = self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
+        job_full_url, job_status = await self.execute_job(query=params["q"], url=f"{self.url_base}{path}")
+        if not job_full_url:
+            if job_status == "Failed":
+                # As a rule, the BULK logic returns an unhandled error. For instance:
+                # error message: 'Unexpected exception encountered in query processing.
+                # Please contact support with the following id: 326566388-63578 (-436445966)'"
+                # So we try to switch to a standard GET sync request, whose response returns an explicit error message
+                standard_instance = self.get_standard_instance()
+                self.logger.warning("Switching to STANDARD (non-BULK) sync, because the Salesforce BULK job returned a failed status")
+                stream_is_available, error = standard_instance.check_availability(self.logger, None)
+                if not stream_is_available:
+                    self.logger.warning(f"Skipped syncing stream '{standard_instance.name}' because it was unavailable. Error: {error}")
+                    return
+                for record in standard_instance.read_records(
+                    sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
+                ):
+                    yield record
+                return
+            raise SalesforceException(f"Job for {self.name} stream using BULK API failed.")
+        salesforce_bulk_api_locator = None
+        while True:
+            # requests' PreparedRequest is used only to build the results URL with the locator query parameter
+            req = PreparedRequest()
+            req.prepare_url(f"{job_full_url}/results", {"locator": salesforce_bulk_api_locator})
+            tmp_file, response_encoding, response_headers = await self.download_data(url=req.url)
+            for record in self.read_with_chunks(tmp_file, response_encoding):
+                yield record
+
+            # Salesforce sets the Sforce-Locator header to "null" on the last page of results
+            if response_headers.get("Sforce-Locator", "null") == "null":
+                break
+            salesforce_bulk_api_locator = response_headers.get("Sforce-Locator")
+
+        await self.delete_job(url=job_full_url)
+
+    def get_standard_instance(self) -> SalesforceStream:
+        """Returns an instance of the standard (non-BULK) stream with the same settings"""
+        stream_kwargs = dict(
+            sf_api=self.sf_api,
+            pk=self.pk,
+            stream_name=self.stream_name,
+            schema=self.schema,
+            sobject_options=self.sobject_options,
+            authenticator=self.authenticator,
+        )
+        new_cls: Type[SalesforceStream] = RestSalesforceStream
+        if isinstance(self, BulkIncrementalSalesforceStream):
+            stream_kwargs.update({"replication_key": self.replication_key, "start_date": self.start_date})
+            new_cls = IncrementalRestSalesforceStream
+
+        return new_cls(**stream_kwargs)
+
+
+class BulkSalesforceSubStream(BatchedSubStream, BulkSalesforceStream):
+    pass
+
+
+@BulkSalesforceStream.transformer.registerCustomTransform
+def transform_empty_string_to_none(instance: Any, schema: Any):
+    """
+    The BULK API returns a `csv` file, where every value initially arrives as a string.
+    This custom transformer replaces empty strings with a `None` value.
+ """ + if isinstance(instance, str) and not instance.strip(): + instance = None + + return instance + + +class IncrementalRestSalesforceStream(RestSalesforceStream, ABC): + state_checkpoint_interval = 500 + STREAM_SLICE_STEP = 30 + _slice = None + + def __init__(self, replication_key: str, **kwargs): + super().__init__(**kwargs) + self.replication_key = replication_key + + async def stream_slices( + self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None + ) -> Iterable[Optional[Mapping[str, Any]]]: + start, end = (None, None) + now = pendulum.now(tz="UTC") + initial_date = pendulum.parse((stream_state or {}).get(self.cursor_field, self.start_date), tz="UTC") + + slice_number = 1 + while not end == now: + start = initial_date.add(days=(slice_number - 1) * self.STREAM_SLICE_STEP) + end = min(now, initial_date.add(days=slice_number * self.STREAM_SLICE_STEP)) + self._slice = {"start_date": start.isoformat(timespec="milliseconds"), "end_date": end.isoformat(timespec="milliseconds")} + yield {"start_date": start.isoformat(timespec="milliseconds"), "end_date": end.isoformat(timespec="milliseconds")} + slice_number = slice_number + 1 + + def request_params( + self, + stream_state: Mapping[str, Any], + stream_slice: Mapping[str, Any] = None, + next_page_token: Mapping[str, Any] = None, + property_chunk: Mapping[str, Any] = None, + ) -> MutableMapping[str, Any]: + if next_page_token: + """ + If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters. + """ + return {} + + property_chunk = property_chunk or {} + + start_date = max( + (stream_state or {}).get(self.cursor_field, self.start_date), + (stream_slice or {}).get("start_date", ""), + (next_page_token or {}).get("start_date", ""), + ) + end_date = (stream_slice or {}).get("end_date", pendulum.now(tz="UTC").isoformat(timespec="milliseconds")) + + select_fields = ",".join(property_chunk.keys()) + table_name = self.name + where_conditions = [] + + if start_date: + where_conditions.append(f"{self.cursor_field} >= {start_date}") + if end_date: + where_conditions.append(f"{self.cursor_field} < {end_date}") + + where_clause = f"WHERE {' AND '.join(where_conditions)}" + query = f"SELECT {select_fields} FROM {table_name} {where_clause}" + + return {"q": query} + + @property + def cursor_field(self) -> str: + return self.replication_key + + def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Return the latest state by comparing the cursor value in the latest record with the stream's most recent state + object and returning an updated state object. 
Check if latest record is IN stream slice interval => ignore if not + """ + latest_record_value: pendulum.DateTime = pendulum.parse(latest_record[self.cursor_field]) + slice_max_value: pendulum.DateTime = pendulum.parse(self._slice.get("end_date")) + max_possible_value = min(latest_record_value, slice_max_value) + if current_stream_state.get(self.cursor_field): + if latest_record_value > slice_max_value: + return {self.cursor_field: max_possible_value.isoformat()} + max_possible_value = max(latest_record_value, pendulum.parse(current_stream_state[self.cursor_field])) + return {self.cursor_field: max_possible_value.isoformat()} + + +class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalRestSalesforceStream): + state_checkpoint_interval = None + + def request_params( + self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None + ) -> MutableMapping[str, Any]: + start_date = stream_slice["start_date"] + end_date = stream_slice["end_date"] + + select_fields = self.get_query_select_fields() + table_name = self.name + where_conditions = [f"{self.cursor_field} >= {start_date}", f"{self.cursor_field} < {end_date}"] + + where_clause = f"WHERE {' AND '.join(where_conditions)}" + query = f"SELECT {select_fields} FROM {table_name} {where_clause}" + return {"q": query} + + +class Describe(Stream): + """ + Stream of sObjects' (Salesforce Objects) describe: + https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_sobject_describe.htm + """ + + name = "Describe" + primary_key = "name" + + def __init__(self, sf_api: Salesforce, catalog: ConfiguredAirbyteCatalog = None, **kwargs): + super().__init__(**kwargs) + self.sf_api = sf_api + if catalog: + self.sobjects_to_describe = [s.stream.name for s in catalog.streams if s.stream.name != self.name] + + def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: + """ + Yield describe response of SObjects defined in catalog as streams only. + """ + for sobject in self.sobjects_to_describe: + yield self.sf_api.describe(sobject=sobject) diff --git a/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test_async.py b/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test_async.py new file mode 100644 index 000000000000..19176a71115b --- /dev/null +++ b/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test_async.py @@ -0,0 +1,1073 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+# +import asyncio +import csv +import io +import logging +import re +from datetime import datetime +from typing import List +from unittest.mock import Mock +from yarl import URL + +import freezegun +import pendulum +import pytest +import requests_mock +from aioresponses import CallbackResult, aioresponses +from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, DestinationSyncMode, SyncMode, Type +from airbyte_cdk.sources.async_cdk import source_dispatcher +from airbyte_cdk.sources.streams import Stream +from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade +from airbyte_cdk.sources.streams.http.utils import HttpError +from airbyte_cdk.utils import AirbyteTracedException +from conftest import encoding_symbols_parameters, generate_stream_async +from source_salesforce.api import Salesforce +from source_salesforce.exceptions import AUTHENTICATION_ERROR_MESSAGE_MAPPING +from source_salesforce.async_salesforce.source import SalesforceSourceDispatcher, AsyncSourceSalesforce +from source_salesforce.async_salesforce.streams import ( + CSV_FIELD_SIZE_LIMIT, + BulkIncrementalSalesforceStream, + BulkSalesforceStream, + BulkSalesforceSubStream, + IncrementalRestSalesforceStream, + RestSalesforceStream, +) + +_ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []}) +_ANY_CONFIG = {} + + +@pytest.mark.parametrize( + "login_status_code, login_json_resp, expected_error_msg, is_config_error", + [ + ( + 400, + {"error": "invalid_grant", "error_description": "expired access/refresh token"}, + AUTHENTICATION_ERROR_MESSAGE_MAPPING.get("expired access/refresh token"), + True, + ), + ( + 400, + {"error": "invalid_grant", "error_description": "Authentication failure."}, + 'An error occurred: {"error": "invalid_grant", "error_description": "Authentication failure."}', + False, + ), + ( + 401, + {"error": "Unauthorized", "error_description": "Unautorized"}, + 'An error occurred: {"error": "Unauthorized", "error_description": "Unautorized"}', + False, + ), + ], +) +def test_login_authentication_error_handler( + stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg, is_config_error +): + source = SalesforceSourceDispatcher(AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)) + logger = logging.getLogger("airbyte") + requests_mock.register_uri( + "POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code + ) + + if is_config_error: + with pytest.raises(AirbyteTracedException) as err: + source.check_connection(logger, stream_config) + assert err.value.message == expected_error_msg + else: + result, msg = source.check_connection(logger, stream_config) + assert result is False + assert msg == expected_error_msg + + +@pytest.mark.asyncio +async def test_bulk_sync_creation_failed(stream_config, stream_api): + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + def callback(*args, **kwargs): + return CallbackResult(status=400, payload={"message": "test_error"}) + + with aioresponses() as m: + m.post("https://fase-account.salesforce.com/services/data/v57.0/jobs/query", status=400, callback=callback) + with pytest.raises(HttpError) as err: + stream_slices = await anext(stream.stream_slices(sync_mode=SyncMode.incremental)) + [r async for r in stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices)] + + assert err.value.json()["message"] == "test_error" + await 
stream._session.close() + + +@pytest.mark.asyncio +async def test_bulk_stream_fallback_to_rest(stream_config, stream_api): + """ + Here we mock BULK API with response returning error, saying BULK is not supported for this kind of entity. + On the other hand, we mock REST API for this same entity with a successful response. + After having instantiated a BulkStream, sync should succeed in case it falls back to REST API. Otherwise it would throw an error. + """ + stream = await generate_stream_async("CustomEntity", stream_config, stream_api) + await stream.ensure_session() + + def callback(*args, **kwargs): + return CallbackResult(status=400, payload={"errorCode": "INVALIDENTITY", "message": "CustomEntity is not supported by the Bulk API"}, content_type="application/json") + + rest_stream_records = [ + {"id": 1, "name": "custom entity", "created": "2010-11-11"}, + {"id": 11, "name": "custom entity", "created": "2020-01-02"}, + ] + async def get_records(*args, **kwargs): + nonlocal rest_stream_records + for record in rest_stream_records: + yield record + + with aioresponses() as m: + # mock a BULK API + m.post("https://fase-account.salesforce.com/services/data/v57.0/jobs/query", status=400, callback=callback) + # mock REST API + stream.read_records = get_records + assert type(stream) is BulkIncrementalSalesforceStream + stream_slices = await anext(stream.stream_slices(sync_mode=SyncMode.incremental)) + assert [r async for r in stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices)] == rest_stream_records + + await stream._session.close() + + +@pytest.mark.asyncio +async def test_stream_unsupported_by_bulk(stream_config, stream_api): + """ + Stream `AcceptedEventRelation` is not supported by BULK API, so that REST API stream will be used for it. + """ + stream_name = "AcceptedEventRelation" + stream = await generate_stream_async(stream_name, stream_config, stream_api) + assert not isinstance(stream, BulkSalesforceStream) + + +@pytest.mark.asyncio +async def test_stream_contains_unsupported_properties_by_bulk(stream_config, stream_api_v2): + """ + Stream `Account` contains compound field such as BillingAddress, which is not supported by BULK API (csv), + in that case REST API stream will be used for it. 
+ """ + stream_name = "Account" + stream = await generate_stream_async(stream_name, stream_config, stream_api_v2) + assert not isinstance(stream, BulkSalesforceStream) + + +@pytest.mark.asyncio +async def test_bulk_sync_pagination(stream_config, stream_api): + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + job_id = "fake_job" + call_counter = 0 + + def cb1(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + return CallbackResult(headers={"Sforce-Locator": "somelocator_1"}, body="\n".join(resp_text)) + + def cb2(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + return CallbackResult(headers={"Sforce-Locator": "somelocator_2"}, body="\n".join(resp_text)) + + def cb3(*args, **kwargs): + nonlocal call_counter + call_counter += 1 + return CallbackResult(headers={"Sforce-Locator": "null"}, body="\n".join(resp_text)) + + with aioresponses() as m: + base_url = f"{stream.sf_api.instance_url}{stream.path()}" + m.post(f"{base_url}", callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id})) + m.get(f"{base_url}/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + resp_text = ["Field1,LastModifiedDate,ID"] + [f"test,2021-11-16,{i}" for i in range(5)] + m.get(f"{base_url}/{job_id}/results", callback=cb1) + m.get(f"{base_url}/{job_id}/results?locator=somelocator_1", callback=cb2) + m.get(f"{base_url}/{job_id}/results?locator=somelocator_2", callback=cb3) + m.delete(base_url + f"/{job_id}") + + stream_slices = await anext(stream.stream_slices(sync_mode=SyncMode.incremental)) + loaded_ids = [int(record["ID"]) async for record in stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices)] + assert loaded_ids == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + assert call_counter == 3 + await stream._session.close() + + + + +def _prepare_mock(m, stream): + job_id = "fake_job_1" + base_url = f"{stream.sf_api.instance_url}{stream.path()}" + m.post(base_url, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id})) + m.delete(base_url + f"/{job_id}") + m.get(base_url + f"/{job_id}/results", callback=lambda *args, **kwargs: CallbackResult(body="Field1,LastModifiedDate,ID\ntest,2021-11-16,1")) + m.patch(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(body="")) + return job_id + + +async def _get_result_id(stream): + stream_slices = await anext(stream.stream_slices(sync_mode=SyncMode.incremental)) + records = [r async for r in stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices)] + return int(list(records)[0]["ID"]) + + +@pytest.mark.asyncio +async def test_bulk_sync_successful(stream_config, stream_api): + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + base_url = f"{stream.sf_api.instance_url}{stream.path()}" + + with aioresponses() as m: + m.post(base_url, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id})) + + with aioresponses() as m: + job_id = _prepare_mock(m, stream) + m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + assert await _get_result_id(stream) == 1 + + +@pytest.mark.asyncio +async def test_bulk_sync_successful_long_response(stream_config, stream_api): + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, 
stream_api)
+    await stream.ensure_session()
+    base_url = f"{stream.sf_api.instance_url}{stream.path()}"
+
+    with aioresponses() as m:
+        job_id = _prepare_mock(m, stream)
+        m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "UploadComplete", "id": job_id}))
+        m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "InProgress", "id": job_id}))
+        m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete", "id": job_id}))
+        assert await _get_result_id(stream) == 1
+
+
+# maximum timeout is wait_timeout * max_retry_attempt
+# this test checks the job state 17 times, with roughly one second between checks
+@pytest.mark.asyncio
+@pytest.mark.timeout(17)
+async def test_bulk_sync_successful_retry(stream_config, stream_api):
+    stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api)
+    stream.DEFAULT_WAIT_TIMEOUT_SECONDS = 6  # maximum wait timeout will be 6 seconds
+    await stream.ensure_session()
+    base_url = f"{stream.sf_api.instance_url}{stream.path()}"
+
+    with aioresponses() as m:
+        job_id = _prepare_mock(m, stream)
+        # 17 "InProgress" responses are returned, then the job completes
+        for _ in range(17):
+            m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "InProgress", "id": job_id}))
+        m.get(base_url + f"/{job_id}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete", "id": job_id}))
+
+        assert await _get_result_id(stream) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.timeout(30)
+async def test_bulk_sync_failed_retry(stream_config, stream_api):
+    stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api)
+    stream.DEFAULT_WAIT_TIMEOUT_SECONDS = 6  # maximum wait timeout will be 6 seconds
+    await stream.ensure_session()
+    base_url = f"{stream.sf_api.instance_url}{stream.path()}"
+
+    with aioresponses() as m:
+        job_id = _prepare_mock(m, stream)
+        m.get(base_url + f"/{job_id}", repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload={"state": "InProgress", "id": job_id}))
+        m.post(base_url, repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id}))
+        with pytest.raises(Exception) as err:
+            stream_slices = await anext(stream.stream_slices(sync_mode=SyncMode.incremental))
+            [record async for record in stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices)]
+        assert "stream using BULK API failed" in str(err.value)
+
+    await stream._session.close()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "start_date_provided,stream_name,expected_start_date",
+    [
+        (True, "Account", "2010-01-18T21:18:20Z"),
+        (True, "ActiveFeatureLicenseMetric", "2010-01-18T21:18:20Z"),
+    ],
+)
+async def test_stream_start_date(
+    start_date_provided,
+    stream_name,
+    expected_start_date,
+    stream_config,
+    stream_api,
+    stream_config_without_start_date,
+):
+    if start_date_provided:
+        stream = await generate_stream_async(stream_name, stream_config, stream_api)
+        assert stream.start_date == expected_start_date
+    else:
+        stream = await generate_stream_async(stream_name, stream_config_without_start_date, stream_api)
+        assert datetime.strptime(stream.start_date, "%Y-%m-%dT%H:%M:%SZ").year == datetime.now().year - 
2 + + +@pytest.mark.asyncio +async def test_stream_start_date_should_be_converted_to_datetime_format(stream_config_date_format, stream_api): + stream: IncrementalRestSalesforceStream = await generate_stream_async("ActiveFeatureLicenseMetric", stream_config_date_format, stream_api) + assert stream.start_date == "2010-01-18T00:00:00Z" + + +@pytest.mark.asyncio +async def test_stream_start_datetime_format_should_not_changed(stream_config, stream_api): + stream: IncrementalRestSalesforceStream = await generate_stream_async("ActiveFeatureLicenseMetric", stream_config, stream_api) + assert stream.start_date == "2010-01-18T21:18:20Z" + + +@pytest.mark.asyncio +async def test_download_data_filter_null_bytes(stream_config, stream_api): + job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + with aioresponses() as m: + m.get(job_full_url_results, callback=lambda *args, **kwargs: CallbackResult(body=b"\x00")) + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == [] + + m.get(job_full_url_results, callback=lambda *args, **kwargs: CallbackResult(body=b'"Id","IsDeleted"\n\x00"0014W000027f6UwQAI","false"\n\x00\x00')) + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == [{"Id": "0014W000027f6UwQAI", "IsDeleted": "false"}] + + +@pytest.mark.asyncio +async def test_read_with_chunks_should_return_only_object_data_type(stream_config, stream_api): + job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + with aioresponses() as m: + m.get(job_full_url_results, callback=lambda *args, **kwargs: CallbackResult(body=b'"IsDeleted","Age"\n"false",24\n')) + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == [{"IsDeleted": "false", "Age": "24"}] + + +@pytest.mark.asyncio +async def test_read_with_chunks_should_return_a_string_when_a_string_with_only_digits_is_provided(stream_config, stream_api): + job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + with aioresponses() as m: + m.get(job_full_url_results, body=b'"ZipCode"\n"01234"\n') + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == [{"ZipCode": "01234"}] + + +@pytest.mark.asyncio +async def test_read_with_chunks_should_return_null_value_when_no_data_is_provided(stream_config, stream_api): + job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + with 
aioresponses() as m: + m.get(job_full_url_results, body=b'"IsDeleted","Age","Name"\n"false",,"Airbyte"\n') + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == [{"IsDeleted": "false", "Age": None, "Name": "Airbyte"}] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "chunk_size, content_type_header, content, expected_result", + encoding_symbols_parameters(), + ids=[f"charset: {x[1]}, chunk_size: {x[0]}" for x in encoding_symbols_parameters()], +) +async def test_encoding_symbols(stream_config, stream_api, chunk_size, content_type_header, content, expected_result): + job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + + with aioresponses() as m: + m.get(job_full_url_results, headers=content_type_header, body=content) + tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results) + res = list(stream.read_with_chunks(tmp_file, response_encoding)) + assert res == expected_result + + +@pytest.mark.parametrize( + "login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg", + ( + (403, [{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}], 200, {}, "API Call limit is exceeded"), + ( + 200, + {"access_token": "access_token", "instance_url": "https://instance_url"}, + 403, + [{"errorCode": "FORBIDDEN", "message": "You do not have enough permissions"}], + 'An error occurred: [{"errorCode": "FORBIDDEN", "message": "You do not have enough permissions"}]', + ), + ), +) +async def test_check_connection_rate_limit( + stream_config, login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg +): + source = AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + logger = logging.getLogger("airbyte") + + with requests_mock.Mocker() as m: + m.register_uri("POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code) + m.register_uri( + "GET", "https://instance_url/services/data/v57.0/sobjects", json=discovery_resp_json, status_code=discovery_status_code + ) + result, msg = source.check_connection(logger, stream_config) + assert result is False + assert msg == expected_error_msg + + +def configure_request_params_mock(stream_1, stream_2): + stream_1.request_params = Mock() + stream_1.request_params.return_value = {"q": "query"} + + stream_2.request_params = Mock() + stream_2.request_params.return_value = {"q": "query"} + + +def test_rate_limit_bulk(stream_config, stream_api, bulk_catalog, state): + """ + Connector should stop the sync if one stream reached rate limit + stream_1, stream_2, stream_3, ... + While reading `stream_1` if 403 (Rate Limit) is received, it should finish that stream with success and stop the sync process. + Next streams should not be executed. 
+ """ + source_dispatcher.DEFAULT_SESSION_LIMIT = 1 # ensure that only one stream runs at a time + stream_config.update({"start_date": "2021-10-01"}) + loop = asyncio.get_event_loop() + stream_1: BulkIncrementalSalesforceStream = loop.run_until_complete(generate_stream_async("Account", stream_config, stream_api)) + stream_2: BulkIncrementalSalesforceStream = loop.run_until_complete(generate_stream_async("Asset", stream_config, stream_api)) + streams = [stream_1, stream_2] + configure_request_params_mock(stream_1, stream_2) + + stream_1.page_size = 6 + stream_1.state_checkpoint_interval = 5 + + source = SalesforceSourceDispatcher(AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)) + source.streams = Mock() + source.streams.return_value = streams + logger = logging.getLogger("airbyte") + + json_response = {"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."} + + orig_read_stream = source.async_source.read_stream + + async def patched_read_stream(*args, **kwargs): + base_url = f"{stream_1.sf_api.instance_url}{stream_1.path()}" + with aioresponses() as m: + creation_responses = [] + for page in [1, 2]: + job_id = f"fake_job_{page}_{stream_1.name}" + creation_responses.append({"id": job_id}) + + m.get(base_url + f"/{job_id}", callback=lambda *_, **__: CallbackResult(payload={"state": "JobComplete"})) + + resp = ["Field1,LastModifiedDate,Id"] + [f"test,2021-10-0{i},{i}" for i in range(1, 7)] # 6 records per page + + if page == 1: + # Read the first page successfully + m.get(base_url + f"/{job_id}/results", callback=lambda *_, **__: CallbackResult(body="\n".join(resp))) + else: + # Requesting for results when reading second page should fail with 403 (Rate Limit error) + m.get(base_url + f"/{job_id}/results", status=403, callback=lambda *_, **__: CallbackResult(status=403, payload=json_response)) + + m.delete(base_url + f"/{job_id}") + + def cb(response): + return lambda *_, **__: CallbackResult(payload=response) + + for response in creation_responses: + m.post(base_url, callback=cb(response)) + + async for r in orig_read_stream(**kwargs): + yield r + + source.async_source.read_stream = patched_read_stream + + result = [i for i in source.read(logger=logger, config=stream_config, catalog=bulk_catalog, state=state)] + assert stream_1.request_params.called + assert ( + not stream_2.request_params.called + ), "The second stream should not be executed, because the first stream finished with Rate Limit." + + records = [item for item in result if item.type == Type.RECORD] + assert len(records) == 6 # stream page size: 6 + + state_record = [item for item in result if item.type == Type.STATE][0] + assert state_record.state.data["Account"]["LastModifiedDate"] == "2021-10-05T00:00:00+00:00" # state checkpoint interval is 5. 
+ + +@pytest.mark.asyncio +async def test_rate_limit_rest(stream_config, stream_api, rest_catalog, state): + source_dispatcher.DEFAULT_SESSION_LIMIT = 1 # ensure that only one stream runs at a time + stream_config.update({"start_date": "2021-11-01"}) + stream_1: IncrementalRestSalesforceStream = await generate_stream_async("KnowledgeArticle", stream_config, stream_api) + stream_2: IncrementalRestSalesforceStream = await generate_stream_async("AcceptedEventRelation", stream_config, stream_api) + stream_1.state_checkpoint_interval = 3 + configure_request_params_mock(stream_1, stream_2) + + source = SalesforceSourceDispatcher(AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)) + source.streams = Mock() + source.streams.return_value = [stream_1, stream_2] + + logger = logging.getLogger("airbyte") + + next_page_url = "/services/data/v57.0/query/012345" + response_1 = { + "done": False, + "totalSize": 10, + "nextRecordsUrl": next_page_url, + "records": [ + { + "ID": 1, + "LastModifiedDate": "2021-11-15", + }, + { + "ID": 2, + "LastModifiedDate": "2021-11-16", + }, + { + "ID": 3, + "LastModifiedDate": "2021-11-17", # check point interval + }, + { + "ID": 4, + "LastModifiedDate": "2021-11-18", + }, + { + "ID": 5, + "LastModifiedDate": "2021-11-19", + }, + ], + } + response_2 = {"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."} + + def cb1(*args, **kwargs): + return CallbackResult(payload=response_1, status=200) + + def cb2(*args, **kwargs): + return CallbackResult(payload=response_2, status=403, reason="") + + orig_read_records_s1 = stream_1.read_records + orig_read_records_s2 = stream_2.read_records + + async def patched_read_records_s1(*args, **kwargs): + with aioresponses() as m: + m.get(re.compile(re.escape(rf"{stream_1.sf_api.instance_url}{stream_1.path()}") + rf"\??.*"), repeat=True, callback=cb1) + m.get(re.compile(re.escape(rf"{stream_1.sf_api.instance_url}{next_page_url}") + rf"\??.*"), repeat=True, callback=cb2) + + async for r in orig_read_records_s1(**kwargs): + yield r + + async def patched_read_records_s2(*args, **kwargs): + with aioresponses() as m: + m.get(re.compile(re.escape(rf"{stream_2.sf_api.instance_url}{stream_2.path()}") + rf"\??.*"), repeat=True, callback=cb1) + m.get(re.compile(re.escape(rf"{stream_2.sf_api.instance_url}{next_page_url}") + rf"\??.*"), repeat=True, callback=cb1) + async for r in orig_read_records_s2(**kwargs): + yield r + + async def check_availability(*args, **kwargs): + return (True, None) + + stream_1.read_records = lambda *args, **kwargs: patched_read_records_s1(stream_1, *args, **kwargs) + stream_1.check_availability = check_availability + stream_2.read_records = lambda *args, **kwargs: patched_read_records_s2(stream_2, *args, **kwargs) + stream_2.check_availability = check_availability + + result = [i for i in source.read(logger=logger, config=stream_config, catalog=rest_catalog, state=state)] + + assert stream_1.request_params.called + assert ( + not stream_2.request_params.called + ), "The second stream should not be executed, because the first stream finished with Rate Limit." 
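+    # only stream_1's first page (5 records) should be emitted; the 403 on the next page stops the sync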
+ + records = [item for item in result if item.type == Type.RECORD] + assert len(records) == 5 + + state_record = [item for item in result if item.type == Type.STATE][0] + assert state_record.state.data["KnowledgeArticle"]["LastModifiedDate"] == "2021-11-17T00:00:00+00:00" + + +@pytest.mark.asyncio +async def test_pagination_rest(stream_config, stream_api): + stream_name = "AcceptedEventRelation" + stream: RestSalesforceStream = await generate_stream_async(stream_name, stream_config, stream_api) + stream.DEFAULT_WAIT_TIMEOUT_SECONDS = 6 # maximum wait timeout will be 6 seconds + next_page_url = "/services/data/v57.0/query/012345" + await stream.ensure_session() + + resp_1 = { + "done": False, + "totalSize": 4, + "nextRecordsUrl": next_page_url, + "records": [ + { + "ID": 1, + "LastModifiedDate": "2021-11-15", + }, + { + "ID": 2, + "LastModifiedDate": "2021-11-16", + }, + ], + } + resp_2 = { + "done": True, + "totalSize": 4, + "records": [ + { + "ID": 3, + "LastModifiedDate": "2021-11-17", + }, + { + "ID": 4, + "LastModifiedDate": "2021-11-18", + }, + ], + } + + with aioresponses() as m: + m.get(re.compile(r"https://fase-account\.salesforce\.com/services/data/v57\.0\??.*"), callback=lambda *args, **kwargs: CallbackResult(payload=resp_1)) + m.get("https://fase-account.salesforce.com" + next_page_url, repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload=resp_2)) + + records = [record async for record in stream.read_records(sync_mode=SyncMode.full_refresh)] + assert len(records) == 4 + + +@pytest.mark.asyncio +async def test_csv_reader_dialect_unix(): + stream: BulkSalesforceStream = BulkSalesforceStream(stream_name=None, sf_api=None, pk=None) + url_results = "https://fake-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results" + await stream.ensure_session() + + data = [ + {"Id": "1", "Name": '"first_name" "last_name"'}, + {"Id": "2", "Name": "'" + 'first_name"\n' + "'" + 'last_name\n"'}, + {"Id": "3", "Name": "first_name last_name"}, + ] + + with io.StringIO("", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=["Id", "Name"], dialect="unix") + writer.writeheader() + for line in data: + writer.writerow(line) + text = csvfile.getvalue() + + with aioresponses() as m: + m.get(url_results, callback=lambda *args, **kwargs: CallbackResult(body=text)) + tmp_file, response_encoding, _ = await stream.download_data(url=url_results) + result = [i for i in stream.read_with_chunks(tmp_file, response_encoding)] + assert result == data + + +@pytest.mark.parametrize( + "stream_names,catalog_stream_names,", + ( + ( + ["stream_1", "stream_2", "Describe"], + None, + ), + ( + ["stream_1", "stream_2"], + ["stream_1", "stream_2", "Describe"], + ), + ( + ["stream_1", "stream_2", "stream_3", "Describe"], + ["stream_1", "Describe"], + ), + ), +) +async def test_forwarding_sobject_options(stream_config, stream_names, catalog_stream_names) -> None: + sobjects_matcher = re.compile("/sobjects$") + token_matcher = re.compile("/token$") + describe_matcher = re.compile("/describe$") + catalog = None + if catalog_stream_names: + catalog = ConfiguredAirbyteCatalog( + streams=[ + ConfiguredAirbyteStream( + stream=AirbyteStream( + name=catalog_stream_name, supported_sync_modes=[SyncMode.full_refresh], json_schema={"type": "object"} + ), + sync_mode=SyncMode.full_refresh, + destination_sync_mode=DestinationSyncMode.overwrite, + ) + for catalog_stream_name in catalog_stream_names + ] + ) + with requests_mock.Mocker() as m: + m.register_uri("POST", token_matcher, 
json={"instance_url": "https://fake-url.com", "access_token": "fake-token"}) + m.register_uri( + "GET", + describe_matcher, + json={ + "fields": [ + { + "name": "field", + "type": "string", + } + ] + }, + ) + m.register_uri( + "GET", + sobjects_matcher, + json={ + "sobjects": [ + { + "name": stream_name, + "flag1": True, + "queryable": True, + } + for stream_name in stream_names + if stream_name != "Describe" + ], + }, + ) + source = AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source.catalog = catalog + streams = source.streams(config=stream_config) + expected_names = catalog_stream_names if catalog else stream_names + assert not set(expected_names).symmetric_difference(set(stream.name for stream in streams)), "doesn't match excepted streams" + + for stream in streams: + if stream.name != "Describe": + if isinstance(stream, StreamFacade): + assert stream._legacy_stream.sobject_options == {"flag1": True, "queryable": True} + else: + assert stream.sobject_options == {"flag1": True, "queryable": True} + return + + +def _get_streams(stream_config, stream_names, catalog_stream_names, sync_type) -> List[Stream]: + sobjects_matcher = re.compile("/sobjects$") + token_matcher = re.compile("/token$") + describe_matcher = re.compile("/describe$") + catalog = None + if catalog_stream_names: + catalog = ConfiguredAirbyteCatalog( + streams=[ + ConfiguredAirbyteStream( + stream=AirbyteStream(name=catalog_stream_name, supported_sync_modes=[sync_type], json_schema={"type": "object"}), + sync_mode=sync_type, + destination_sync_mode=DestinationSyncMode.overwrite, + ) + for catalog_stream_name in catalog_stream_names + ] + ) + with requests_mock.Mocker() as m: + m.register_uri("POST", token_matcher, json={"instance_url": "https://fake-url.com", "access_token": "fake-token"}) + m.register_uri( + "GET", + describe_matcher, + json={ + "fields": [ + { + "name": "field", + "type": "string", + } + ] + }, + ) + m.register_uri( + "GET", + sobjects_matcher, + json={ + "sobjects": [ + { + "name": stream_name, + "flag1": True, + "queryable": True, + } + for stream_name in stream_names + if stream_name != "Describe" + ], + }, + ) + source = AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source.catalog = catalog + return source.streams(config=stream_config) + + +def test_csv_field_size_limit(): + DEFAULT_CSV_FIELD_SIZE_LIMIT = 1024 * 128 + + field_size = 1024 * 1024 + text = '"Id","Name"\n"1","' + field_size * "a" + '"\n' + + csv.field_size_limit(DEFAULT_CSV_FIELD_SIZE_LIMIT) + reader = csv.reader(io.StringIO(text)) + with pytest.raises(csv.Error): + for _ in reader: + pass + + csv.field_size_limit(CSV_FIELD_SIZE_LIMIT) + reader = csv.reader(io.StringIO(text)) + for _ in reader: + pass + + +@pytest.mark.asyncio +async def test_convert_to_standard_instance(stream_config, stream_api): + bulk_stream = await generate_stream_async("Account", stream_config, stream_api) + rest_stream = bulk_stream.get_standard_instance() + assert isinstance(rest_stream, IncrementalRestSalesforceStream) + + +@pytest.mark.asyncio +async def test_rest_stream_init_with_too_many_properties(stream_config, stream_api_v2_too_many_properties): + with pytest.raises(AssertionError): + # v2 means the stream is going to be a REST stream. 
+ # A missing primary key is not allowed + await generate_stream_async("Account", stream_config, stream_api_v2_too_many_properties) + + +@pytest.mark.asyncio +async def test_too_many_properties(stream_config, stream_api_v2_pk_too_many_properties, requests_mock): + stream = await generate_stream_async("Account", stream_config, stream_api_v2_pk_too_many_properties) + await stream.ensure_session() + chunks = list(stream.chunk_properties()) + for chunk in chunks: + assert stream.primary_key in chunk + chunks_len = len(chunks) + assert stream.too_many_properties + assert stream.primary_key + assert type(stream) == RestSalesforceStream + next_page_url = "https://fase-account.salesforce.com/services/data/v57.0/queryAll" + url_pattern = re.compile(r"https://fase-account\.salesforce\.com/services/data/v57\.0/queryAll\??.*") + with aioresponses() as m: + m.get(url_pattern, callback=lambda *args, **kwargs: CallbackResult(payload={ + "records": [ + {"Id": 1, "propertyA": "A"}, + {"Id": 2, "propertyA": "A"}, + {"Id": 3, "propertyA": "A"}, + {"Id": 4, "propertyA": "A"}, + ] + })) + m.get(url_pattern, callback=lambda *args, **kwargs: CallbackResult(payload={"nextRecordsUrl": next_page_url, "records": [{"Id": 1, "propertyB": "B"}, {"Id": 2, "propertyB": "B"}]})) + # 2 for 2 chunks above + for _ in range(chunks_len - 2): + m.get(url_pattern, callback=lambda *args, **kwargs: CallbackResult(payload={"records": [{"Id": 1}, {"Id": 2}], "nextRecordsUrl": next_page_url})) + m.get(url_pattern, callback=lambda *args, **kwargs: CallbackResult(payload={"records": [{"Id": 3, "propertyB": "B"}, {"Id": 4, "propertyB": "B"}]})) + # 2 for 1 chunk above and 1 chunk had no next page + for _ in range(chunks_len - 2): + m.get(url_pattern, callback=lambda *args, **kwargs: CallbackResult(payload={"records": [{"Id": 3}, {"Id": 4}]})) + + records = [r async for r in stream.read_records(sync_mode=SyncMode.full_refresh)] + assert records == [ + {"Id": 1, "propertyA": "A", "propertyB": "B"}, + {"Id": 2, "propertyA": "A", "propertyB": "B"}, + {"Id": 3, "propertyA": "A", "propertyB": "B"}, + {"Id": 4, "propertyA": "A", "propertyB": "B"}, + ] + for call in requests_mock.request_history: + assert len(call.url) < Salesforce.REQUEST_SIZE_LIMITS + + +@pytest.mark.asyncio +async def test_stream_with_no_records_in_response(stream_config, stream_api_v2_pk_too_many_properties): + stream = await generate_stream_async("Account", stream_config, stream_api_v2_pk_too_many_properties) + chunks = list(stream.chunk_properties()) + for chunk in chunks: + assert stream.primary_key in chunk + assert stream.too_many_properties + assert stream.primary_key + assert type(stream) == RestSalesforceStream + url = re.compile(r"https://fase-account\.salesforce\.com/services/data/v57\.0/queryAll\??.*") + await stream.ensure_session() + + with aioresponses() as m: + m.get(url, repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload={"records": []})) + records = [record async for record in stream.read_records(sync_mode=SyncMode.full_refresh)] + assert records == [] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "status_code,response_json,log_message", + [ + ( + 400, + {"errorCode": "INVALIDENTITY", "message": "Account is not supported by the Bulk API"}, + "Account is not supported by the Bulk API", + ), + (403, {"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "API limit reached"}, "API limit reached"), + (400, {"errorCode": "API_ERROR", "message": "API does not support query"}, "The stream 'Account' is not queryable,"), + ( + 400, + 
{"errorCode": "LIMIT_EXCEEDED", "message": "Max bulk v2 query jobs (10000) per 24 hrs has been reached (10021)"}, + "Your API key for Salesforce has reached its limit for the 24-hour period. We will resume replication once the limit has elapsed.", + ), + ], +) +async def test_bulk_stream_error_in_logs_on_create_job(stream_config, stream_api, status_code, response_json, log_message, caplog): + stream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + url = f"{stream.sf_api.instance_url}/services/data/{stream.sf_api.version}/jobs/query" + + with aioresponses() as m: + m.post(url, status=status_code, callback=lambda *args, **kwargs: CallbackResult(status=status_code, payload=response_json, reason="")) + query = "Select Id, Subject from Account" + with caplog.at_level(logging.ERROR): + assert await stream.create_stream_job(query, url) is None, "this stream should be skipped" + + # check logs + assert log_message in caplog.records[-1].message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "status_code,response_json,error_message", + [ + ( + 400, + { + "errorCode": "TXN_SECURITY_METERING_ERROR", + "message": "We can't complete the action because enabled transaction security policies took too long to complete.", + }, + 'A transient authentication error occurred. To prevent future syncs from failing, assign the "Exempt from Transaction Security" user permission to the authenticated user.', + ), + ], +) +async def test_bulk_stream_error_on_wait_for_job(stream_config, stream_api, status_code, response_json, error_message): + stream = await generate_stream_async("Account", stream_config, stream_api) + await stream.ensure_session() + url = f"{stream.sf_api.instance_url}/services/data/{stream.sf_api.version}/jobs/query/queryJobId" + + with aioresponses() as m: + m.get(url, status=status_code, callback=lambda *args, **kwargs: CallbackResult(status=status_code, payload=response_json, reason="")) + with pytest.raises(AirbyteTracedException) as e: + await stream.wait_for_job(url=url) + assert e.value.message == error_message + + + +@pytest.mark.asyncio() +@freezegun.freeze_time("2023-01-01") +async def test_bulk_stream_slices(stream_config_date_format, stream_api): + stream: BulkIncrementalSalesforceStream = await generate_stream_async("FakeBulkStream", stream_config_date_format, stream_api) + stream_slices = [s async for s in stream.stream_slices(sync_mode=SyncMode.full_refresh)] + expected_slices = [] + today = pendulum.today(tz="UTC") + start_date = pendulum.parse(stream.start_date, tz="UTC") + while start_date < today: + expected_slices.append( + { + "start_date": start_date.isoformat(timespec="milliseconds"), + "end_date": min(today, start_date.add(days=stream.STREAM_SLICE_STEP)).isoformat(timespec="milliseconds"), + } + ) + start_date = start_date.add(days=stream.STREAM_SLICE_STEP) + assert expected_slices == stream_slices + + +@pytest.mark.asyncio +@freezegun.freeze_time("2023-04-01") +async def test_bulk_stream_request_params_states(stream_config_date_format, stream_api, bulk_catalog): + stream_config_date_format.update({"start_date": "2023-01-01"}) + stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config_date_format, stream_api) + await stream.ensure_session() + + source = SalesforceSourceDispatcher(AsyncSourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)) + source.streams = Mock() + source.streams.return_value = [stream] + base_url = f"{stream.sf_api.instance_url}{stream.path()}" + + job_id_1 = "fake_job_1" 
+ job_id_2 = "fake_job_2" + job_id_3 = "fake_job_3" + + with aioresponses() as m: + m.post(base_url, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id_1})) + m.get(base_url + f"/{job_id_1}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + m.delete(base_url + f"/{job_id_1}") + m.get(base_url + f"/{job_id_1}/results", + callback=lambda *args, **kwargs: CallbackResult(body="Field1,LastModifiedDate,ID\ntest,2023-01-15,1")) + m.patch(base_url + f"/{job_id_1}") + + m.post(base_url, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id_2})) + m.get(base_url + f"/{job_id_2}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + m.delete(base_url + f"/{job_id_2}") + m.get(base_url + f"/{job_id_2}/results", + callback=lambda *args, **kwargs: CallbackResult(body="Field1,LastModifiedDate,ID\ntest,2023-04-01,2\ntest,2023-02-20,22")) + m.patch(base_url + f"/{job_id_2}") + + m.post(base_url, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id_3})) + m.get(base_url + f"/{job_id_3}", callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + m.delete(base_url + f"/{job_id_3}") + m.get(base_url + f"/{job_id_3}/results", + callback=lambda *args, **kwargs: CallbackResult(body="Field1,LastModifiedDate,ID\ntest,2023-04-01,3")) + m.patch(base_url + f"/{job_id_3}") + + logger = logging.getLogger("airbyte") + state = {"Account": {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}} + bulk_catalog.streams.pop(1) + result = [i for i in source.read(logger=logger, config=stream_config_date_format, catalog=bulk_catalog, state=state)] + + actual_state_values = [item.state.data.get("Account").get(stream.cursor_field) for item in result if item.type == Type.STATE] + queries_history = m.requests + + # assert request params + assert ( + "LastModifiedDate >= 2023-01-01T10:10:10.000+00:00 AND LastModifiedDate < 2023-01-31T10:10:10.000+00:00" + in queries_history[("POST", URL(base_url))][0].kwargs["json"]["query"] + ) + assert ( + "LastModifiedDate >= 2023-01-31T10:10:10.000+00:00 AND LastModifiedDate < 2023-03-02T10:10:10.000+00:00" + in queries_history[("POST", URL(base_url))][1].kwargs["json"]["query"] + ) + assert ( + "LastModifiedDate >= 2023-03-02T10:10:10.000+00:00 AND LastModifiedDate < 2023-04-01T00:00:00.000+00:00" + in queries_history[("POST", URL(base_url))][2].kwargs["json"]["query"] + ) + + # assert states + expected_state_values = ["2023-01-15T00:00:00+00:00", "2023-03-02T10:10:10+00:00", "2023-04-01T00:00:00+00:00"] + assert actual_state_values == expected_state_values + + +@pytest.mark.asyncio +async def test_request_params_incremental(stream_config_date_format, stream_api): + stream = await generate_stream_async("ContentDocument", stream_config_date_format, stream_api) + params = stream.request_params(stream_state={}, stream_slice={'start_date': '2020', 'end_date': '2021'}) + + assert params == {'q': 'SELECT LastModifiedDate, Id FROM ContentDocument WHERE LastModifiedDate >= 2020 AND LastModifiedDate < 2021'} + + +@pytest.mark.asyncio +async def test_request_params_substream(stream_config_date_format, stream_api): + stream = await generate_stream_async("ContentDocumentLink", stream_config_date_format, stream_api) + params = stream.request_params(stream_state={}, stream_slice={'parents': [{'Id': 1}, {'Id': 2}]}) + + assert params == {"q": "SELECT LastModifiedDate, Id FROM ContentDocumentLink WHERE ContentDocumentId IN ('1','2')"} + + +@pytest.mark.asyncio 
+@freezegun.freeze_time("2023-03-20") +async def test_stream_slices_for_substream(stream_config, stream_api): + stream_config['start_date'] = '2023-01-01' + stream: BulkSalesforceSubStream = await generate_stream_async("ContentDocumentLink", stream_config, stream_api) + stream.SLICE_BATCH_SIZE = 2 # each ContentDocumentLink should contain 2 records from parent ContentDocument stream + await stream.ensure_session() + + job_id = "fake_job" + base_url = f"{stream.sf_api.instance_url}{stream.path()}" + + with aioresponses() as m: + m.post(base_url, repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload={"id": job_id})) + m.get(base_url + f"/{job_id}", repeat=True, callback=lambda *args, **kwargs: CallbackResult(payload={"state": "JobComplete"})) + m.get(base_url + f"/{job_id}/results", repeat=True, callback=lambda *args, **kwargs: CallbackResult(body="Field1,LastModifiedDate,ID\ntest,2021-11-16,123", headers={"Sforce-Locator": "null"})) + m.delete(base_url + f"/{job_id}", repeat=True, callback=lambda *args, **kwargs: CallbackResult()) + + stream_slices = [slice async for slice in stream.stream_slices(sync_mode=SyncMode.full_refresh)] + assert stream_slices == [ + {'parents': [{'Field1': 'test', 'ID': '123', 'LastModifiedDate': '2021-11-16'}, + {'Field1': 'test', 'ID': '123', 'LastModifiedDate': '2021-11-16'}]}, + {'parents': [{'Field1': 'test', 'ID': '123', 'LastModifiedDate': '2021-11-16'}]} + ] diff --git a/airbyte-integrations/connectors/source-salesforce/unit_tests/conftest.py b/airbyte-integrations/connectors/source-salesforce/unit_tests/conftest.py index eeacdd2235d2..da7ccbed50fc 100644 --- a/airbyte-integrations/connectors/source-salesforce/unit_tests/conftest.py +++ b/airbyte-integrations/connectors/source-salesforce/unit_tests/conftest.py @@ -8,6 +8,7 @@ import pytest from airbyte_cdk.models import ConfiguredAirbyteCatalog from source_salesforce.api import Salesforce +from source_salesforce.async_salesforce.source import AsyncSourceSalesforce from source_salesforce.source import SourceSalesforce @@ -120,7 +121,11 @@ def stream_api_v2_pk_too_many_properties(stream_config): def generate_stream(stream_name, stream_config, stream_api): - return SourceSalesforce.generate_streams(stream_config, {stream_name: None}, stream_api)[0] + return (SourceSalesforce.generate_streams(stream_config, {stream_name: None}, stream_api))[0] + + +async def generate_stream_async(stream_name, stream_config, stream_api): + return (await AsyncSourceSalesforce.generate_streams(stream_config, {stream_name: None}, stream_api))[0] def encoding_symbols_parameters(): diff --git a/airbyte-integrations/connectors/source-salesforce/unit_tests/test_memory_async.py b/airbyte-integrations/connectors/source-salesforce/unit_tests/test_memory_async.py new file mode 100644 index 000000000000..1c79fe5bc75a --- /dev/null +++ b/airbyte-integrations/connectors/source-salesforce/unit_tests/test_memory_async.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+#
+
+
+import tracemalloc
+
+import pytest
+from aioresponses import CallbackResult, aioresponses
+from conftest import generate_stream_async
+from source_salesforce.async_salesforce.streams import BulkIncrementalSalesforceStream
+
+
+@pytest.mark.parametrize(
+    "n_records, first_size, first_peak",
+    (
+        (1000, 0.4, 1),
+        (10000, 1, 2),
+        (100000, 4, 9),
+        (200000, 7, 19),
+    ),
+    ids=[
+        "1k records",
+        "10k records",
+        "100k records",
+        "200k records",
+    ],
+)
+@pytest.mark.asyncio
+async def test_memory_download_data(stream_config, stream_api, n_records, first_size, first_peak):
+    job_full_url_results: str = "https://fase-account.salesforce.com/services/data/v57.0/jobs/query/7504W00000bkgnpQAA/results"
+    stream: BulkIncrementalSalesforceStream = await generate_stream_async("Account", stream_config, stream_api)
+    await stream.ensure_session()
+    content = b'"Id","IsDeleted"\n'
+    for _ in range(n_records):
+        content += b'"0014W000027f6UwQAI","false"\n'
+
+    def callback(url, **kwargs):
+        return CallbackResult(body=content)
+
+    with aioresponses() as m:
+        m.get(job_full_url_results, status=200, callback=callback)
+        tracemalloc.start()
+        tmp_file, response_encoding, _ = await stream.download_data(url=job_full_url_results)
+        for x in stream.read_with_chunks(tmp_file, response_encoding):
+            pass
+        fs, fp = tracemalloc.get_traced_memory()
+        first_size_in_mb, first_peak_in_mb = fs / 1024**2, fp / 1024**2
+
+        # the parametrized thresholds are rough upper bounds in MB
+        assert first_size_in_mb < first_size
+        assert first_peak_in_mb < first_peak
+
+    await stream._session.close()