From eb31e4d2ba796f893e1740b5e7a298fd346400ef Mon Sep 17 00:00:00 2001 From: Catherine Noll Date: Mon, 29 Jan 2024 19:33:50 -0500 Subject: [PATCH] File-based CDK: make full refresh concurrent (#34411) --- .../concurrent_source_adapter.py | 6 +- .../availability_strategy/__init__.py | 7 +- ...stract_file_based_availability_strategy.py | 20 + .../sources/file_based/file_based_source.py | 88 ++++- .../file_based/file_types/csv_parser.py | 7 +- .../file_based/stream/concurrent/__init__.py | 0 .../file_based/stream/concurrent/adapters.py | 322 +++++++++++++++ .../file_based/stream/concurrent/cursor.py | 87 +++++ .../stream/default_file_based_stream.py | 8 +- .../concurrent/abstract_stream_facade.py | 37 ++ .../sources/streams/concurrent/adapters.py | 62 +-- .../sources/streams/concurrent/helpers.py | 31 ++ .../file_based/file_types/test_csv_parser.py | 18 +- .../file_based/in_memory_files_source.py | 14 +- .../file_based/scenarios/csv_scenarios.py | 30 +- .../scenarios/file_based_source_builder.py | 7 +- .../file_based/scenarios/jsonl_scenarios.py | 8 +- .../file_based/scenarios/parquet_scenarios.py | 1 + .../file_based/scenarios/scenario_builder.py | 13 +- .../scenarios/validation_policy_scenarios.py | 75 +--- .../stream/concurrent/test_adapters.py | 365 ++++++++++++++++++ .../sources/file_based/test_scenarios.py | 24 +- .../scenarios/stream_facade_builder.py | 5 +- ..._based_concurrent_stream_source_builder.py | 2 +- 24 files changed, 1042 insertions(+), 195 deletions(-) create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/__init__.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/cursor.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/helpers.py create mode 100644 airbyte-cdk/python/unit_tests/sources/file_based/stream/concurrent/test_adapters.py diff --git a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py index 8e2ea80b79ae..6c3b8aa70efb 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py @@ -10,7 +10,7 @@ from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream -from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade +from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade class ConcurrentSourceAdapter(AbstractSource, ABC): @@ -58,6 +58,6 @@ def _select_abstract_streams(self, config: Mapping[str, Any], configured_catalog f"The stream {configured_stream.stream.name} no longer exists in the configuration. " f"Refresh the schema in replication settings and remove this stream from future sync attempts." 
)
-        if isinstance(stream_instance, StreamFacade):
-            abstract_streams.append(stream_instance._abstract_stream)
+        if isinstance(stream_instance, AbstractStreamFacade):
+            abstract_streams.append(stream_instance.get_underlying_stream())
 
         return abstract_streams
 
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
index 983f4eeb8bf7..56204e9b74e6 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
@@ -1,4 +1,7 @@
-from .abstract_file_based_availability_strategy import AbstractFileBasedAvailabilityStrategy
+from .abstract_file_based_availability_strategy import (
+    AbstractFileBasedAvailabilityStrategy,
+    AbstractFileBasedAvailabilityStrategyWrapper,
+)
 from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy
 
-__all__ = ["AbstractFileBasedAvailabilityStrategy", "DefaultFileBasedAvailabilityStrategy"]
+__all__ = ["AbstractFileBasedAvailabilityStrategy", "AbstractFileBasedAvailabilityStrategyWrapper", "DefaultFileBasedAvailabilityStrategy"]
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
index 1ba12f64febd..ba26745ea57c 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
@@ -8,6 +8,12 @@
 
 from airbyte_cdk.sources import Source
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
+from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
+    AbstractAvailabilityStrategy,
+    StreamAvailability,
+    StreamAvailable,
+    StreamUnavailable,
+)
 from airbyte_cdk.sources.streams.core import Stream
 
 if TYPE_CHECKING:
@@ -35,3 +41,17 @@ def check_availability_and_parsability(
 
         Returns (True, None) if successful, otherwise (False, <reason>).
         """
         ...
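For context, a minimal sketch of a concrete strategy satisfying this (bool, Optional[str]) contract follows. It is illustrative only and not part of this patch: the class name and the single-record probe are hypothetical, and it assumes nothing beyond the AbstractFileBasedStream API already used elsewhere in this diff (get_files, read_records_from_slice).

class ProbeFirstFileAvailabilityStrategy(AbstractFileBasedAvailabilityStrategy):
    """Hypothetical example: the stream is available if any file matches its globs,
    and parsable if a single record can be read from the first matching file."""

    def check_availability(self, stream, logger, _source):
        # Cheap check: at least one file matches the configured glob patterns.
        if not any(True for _ in stream.get_files()):
            return False, "No files matched the stream's glob patterns."
        return True, None

    def check_availability_and_parsability(self, stream, logger, _source):
        is_available, reason = self.check_availability(stream, logger, _source)
        if not is_available:
            return False, reason
        file = next(iter(stream.get_files()))
        try:
            # Reading one record exercises both connectivity and the parser.
            next(iter(stream.read_records_from_slice({"files": [file]})), None)
        except Exception as exc:
            return False, f"Unable to parse {file.uri}: {exc}"
        return True, None

Keeping check_availability cheap matters here because the wrapper added below calls it once per concurrent sync.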
+ + +class AbstractFileBasedAvailabilityStrategyWrapper(AbstractAvailabilityStrategy): + def __init__(self, stream: "AbstractFileBasedStream"): + self.stream = stream + + def check_availability(self, logger: logging.Logger) -> StreamAvailability: + is_available, reason = self.stream.availability_strategy.check_availability(self.stream, logger, None) + if is_available: + return StreamAvailable() + return StreamUnavailable(reason or "") + + def check_availability_and_parsability(self, logger: logging.Logger) -> Tuple[bool, Optional[str]]: + return self.stream.availability_strategy.check_availability_and_parsability(self.stream, logger, None) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_based_source.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_based_source.py index 9904e4a8be97..cfdc0cdcedbd 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_based_source.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_based_source.py @@ -8,8 +8,18 @@ from collections import Counter from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Type, Union -from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.logger import AirbyteLogFormatter, init_logger +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteStateMessage, + ConfiguredAirbyteCatalog, + ConnectorSpecification, + FailureType, + Level, + SyncMode, +) +from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource +from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy @@ -20,19 +30,33 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream +from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade +from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor +from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository +from airbyte_cdk.sources.source import TState from airbyte_cdk.sources.streams import Stream from airbyte_cdk.utils.analytics_message import create_analytics_message +from airbyte_cdk.utils.traced_exception import AirbyteTracedException from pydantic.error_wrappers import ValidationError +DEFAULT_CONCURRENCY = 100 +MAX_CONCURRENCY = 100 +INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2 + + +class FileBasedSource(ConcurrentSourceAdapter, ABC): + # We make each source override the concurrency level to give control over when they are upgraded. 
+ _concurrency_level = None -class FileBasedSource(AbstractSource, ABC): def __init__( self, stream_reader: AbstractFileBasedStreamReader, spec_class: Type[AbstractFileBasedSpec], - catalog_path: Optional[str] = None, + catalog: Optional[ConfiguredAirbyteCatalog], + config: Optional[Mapping[str, Any]], + state: Optional[TState], availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None, discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), parsers: Mapping[Type[Any], FileTypeParser] = default_parsers, @@ -41,15 +65,29 @@ def __init__( ): self.stream_reader = stream_reader self.spec_class = spec_class + self.config = config + self.catalog = catalog + self.state = state self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader) self.discovery_policy = discovery_policy self.parsers = parsers self.validation_policies = validation_policies - catalog = self.read_catalog(catalog_path) if catalog_path else None self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {} self.cursor_cls = cursor_cls - self.logger = logging.getLogger(f"airbyte.{self.name}") + self.logger = init_logger(f"airbyte.{self.name}") self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector() + self._message_repository: Optional[MessageRepository] = None + concurrent_source = ConcurrentSource.create( + MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository + ) + self._state = None + super().__init__(concurrent_source) + + @property + def message_repository(self) -> MessageRepository: + if self._message_repository is None: + self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level])) + return self._message_repository def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: """ @@ -61,7 +99,15 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Otherwise, the "error" object should describe what went wrong. """ - streams = self.streams(config) + try: + streams = self.streams(config) + except Exception as config_exception: + raise AirbyteTracedException( + internal_message="Please check the logged errors for more information.", + message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value, + exception=AirbyteTracedException(exception=config_exception), + failure_type=FailureType.config_error, + ) if len(streams) == 0: return ( False, @@ -80,7 +126,7 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> reason, ) = stream.availability_strategy.check_availability_and_parsability(stream, logger, self) except Exception: - errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}") + errors.append(f"Unable to connect to stream {stream.name} - {''.join(traceback.format_exc())}") else: if not stream_is_available and reason: errors.append(reason) @@ -91,10 +137,26 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: """ Return a list of this source's streams. 
""" + file_based_streams = self._get_file_based_streams(config) + + configured_streams: List[Stream] = [] + + for stream in file_based_streams: + sync_mode = self._get_sync_mode_from_catalog(stream) + if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None: + configured_streams.append( + FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor(stream.config)) + ) + else: + configured_streams.append(stream) + + return configured_streams + + def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]: try: parsed_config = self._get_parsed_config(config) self.stream_reader.config = parsed_config - streams: List[Stream] = [] + streams: List[AbstractFileBasedStream] = [] for stream_config in parsed_config.streams: self._validate_input_schema(stream_config) streams.append( @@ -115,6 +177,14 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: except ValidationError as exc: raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc + def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]: + if self.catalog: + for catalog_stream in self.catalog.streams: + if stream.name == catalog_stream.stream.name: + return catalog_stream.sync_mode + raise RuntimeError(f"No sync mode was found for {stream.name}.") + return None + def read( self, logger: logging.Logger, diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py index b67aebcd723e..627c3573b669 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -10,6 +10,7 @@ from functools import partial from io import IOBase from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple +from uuid import uuid4 from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType @@ -38,8 +39,10 @@ def read_data( # Formats are configured individually per-stream so a unique dialect should be registered for each stream. # We don't unregister the dialect because we are lazily parsing each csv file to generate records - # This will potentially be a problem if we ever process multiple streams concurrently - dialect_name = config.name + DIALECT_NAME + # Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up + # with a race condition where a thread attempts to use a dialect before a separate thread has finished + # registering it. 
+ dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}" csv.register_dialect( dialect_name, delimiter=config_format.delimiter, diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py new file mode 100644 index 000000000000..731b04621705 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -0,0 +1,322 @@ +# +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +# + +import copy +import logging +from functools import lru_cache +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union + +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type +from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, + AbstractFileBasedAvailabilityStrategyWrapper, +) +from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType +from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser +from airbyte_cdk.sources.file_based.remote_file import RemoteFile +from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream +from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor +from airbyte_cdk.sources.file_based.types import StreamSlice +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream +from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage +from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.utils.schema_helpers import InternalConfig +from airbyte_cdk.sources.utils.slice_logger import SliceLogger +from deprecated.classic import deprecated + +""" +This module contains adapters to help enabling concurrency on File-based Stream objects without needing to migrate to AbstractStream +""" + + +@deprecated("This class is experimental. Use at your own risk.") +class FileBasedStreamFacade(AbstractStreamFacade[DefaultStream], AbstractFileBasedStream): + @classmethod + def create_from_stream( + cls, + stream: AbstractFileBasedStream, + source: AbstractSource, + logger: logging.Logger, + state: Optional[MutableMapping[str, Any]], + cursor: FileBasedNoopCursor, + ) -> "FileBasedStreamFacade": + """ + Create a ConcurrentStream from a FileBasedStream object. 
+ """ + pk = get_primary_key_from_stream(stream.primary_key) + cursor_field = get_cursor_field_from_stream(stream) + + if not source.message_repository: + raise ValueError( + "A message repository is required to emit non-record messages. Please set the message repository on the source." + ) + + message_repository = source.message_repository + return FileBasedStreamFacade( + DefaultStream( # type: ignore + partition_generator=FileBasedStreamPartitionGenerator( + stream, + message_repository, + SyncMode.full_refresh if isinstance(cursor, FileBasedNoopCursor) else SyncMode.incremental, + [cursor_field] if cursor_field is not None else None, + state, + cursor, + ), + name=stream.name, + json_schema=stream.get_json_schema(), + availability_strategy=AbstractFileBasedAvailabilityStrategyWrapper(stream), + primary_key=pk, + cursor_field=cursor_field, + logger=logger, + namespace=stream.namespace, + ), + stream, + cursor, + logger=logger, + slice_logger=source._slice_logger, + ) + + def __init__( + self, + stream: DefaultStream, + legacy_stream: AbstractFileBasedStream, + cursor: FileBasedNoopCursor, + slice_logger: SliceLogger, + logger: logging.Logger, + ): + """ + :param stream: The underlying AbstractStream + """ + # super().__init__(stream, legacy_stream, cursor, slice_logger, logger) + self._abstract_stream = stream + self._legacy_stream = legacy_stream + self._cursor = cursor + self._slice_logger = slice_logger + self._logger = logger + self.catalog_schema = legacy_stream.catalog_schema + self.config = legacy_stream.config + self.validation_policy = legacy_stream.validation_policy + + @property + def cursor_field(self) -> Union[str, List[str]]: + if self._abstract_stream.cursor_field is None: + return [] + else: + return self._abstract_stream.cursor_field + + @property + def name(self) -> str: + return self._abstract_stream.name + + @property + def supports_incremental(self) -> bool: + return self._legacy_stream.supports_incremental + + @property + def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy: + return self._legacy_stream.availability_strategy + + @lru_cache(maxsize=None) + def get_json_schema(self) -> Mapping[str, Any]: + return self._abstract_stream.get_json_schema() + + @property + def primary_key(self) -> PrimaryKeyType: + return self._legacy_stream.config.primary_key or self.get_parser().get_parser_defined_primary_key(self._legacy_stream.config) + + def get_parser(self) -> FileTypeParser: + return self._legacy_stream.get_parser() + + def get_files(self) -> Iterable[RemoteFile]: + return self._legacy_stream.get_files() + + def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]: + yield from self._legacy_stream.read_records_from_slice(stream_slice) + + def compute_slices(self) -> Iterable[Optional[StreamSlice]]: + return self._legacy_stream.compute_slices() + + def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + return self._legacy_stream.infer_schema(files) + + def get_underlying_stream(self) -> DefaultStream: + return self._abstract_stream + + def read_full_refresh( + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + ) -> Iterable[StreamData]: + """ + Read full refresh. 
Delegate to the underlying AbstractStream, ignoring all the parameters + :param cursor_field: (ignored) + :param logger: (ignored) + :param slice_logger: (ignored) + :return: Iterable of StreamData + """ + yield from self._read_records() + + def read_incremental( + self, + cursor_field: Optional[List[str]], + logger: logging.Logger, + slice_logger: SliceLogger, + stream_state: MutableMapping[str, Any], + state_manager: ConnectorStateManager, + per_stream_state_enabled: bool, + internal_config: InternalConfig, + ) -> Iterable[StreamData]: + yield from self._read_records() + + def read_records( + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_slice: Optional[Mapping[str, Any]] = None, + stream_state: Optional[Mapping[str, Any]] = None, + ) -> Iterable[StreamData]: + try: + yield from self._read_records() + except Exception as exc: + if hasattr(self._cursor, "state"): + state = str(self._cursor.state) + else: + # This shouldn't happen if the ConcurrentCursor was used + state = "unknown; no state attribute was available on the cursor" + yield AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=f"Cursor State at time of exception: {state}") + ) + raise exc + + def _read_records(self) -> Iterable[StreamData]: + for partition in self._abstract_stream.generate_partitions(): + if self._slice_logger.should_log_slice_message(self._logger): + yield self._slice_logger.create_slice_log_message(partition.to_slice()) + for record in partition.read(): + yield record.data + + +class FileBasedStreamPartition(Partition): + def __init__( + self, + stream: AbstractFileBasedStream, + _slice: Optional[Mapping[str, Any]], + message_repository: MessageRepository, + sync_mode: SyncMode, + cursor_field: Optional[List[str]], + state: Optional[MutableMapping[str, Any]], + cursor: FileBasedNoopCursor, + ): + self._stream = stream + self._slice = _slice + self._message_repository = message_repository + self._sync_mode = sync_mode + self._cursor_field = cursor_field + self._state = state + self._cursor = cursor + self._is_closed = False + + def read(self) -> Iterable[Record]: + try: + for record_data in self._stream.read_records( + cursor_field=self._cursor_field, + sync_mode=SyncMode.full_refresh, + stream_slice=copy.deepcopy(self._slice), + stream_state=self._state, + ): + if isinstance(record_data, Mapping): + data_to_return = dict(record_data) + self._stream.transformer.transform(data_to_return, self._stream.get_json_schema()) + yield Record(data_to_return, self.stream_name()) + else: + self._message_repository.emit_message(record_data) + except Exception as e: + display_message = self._stream.get_error_display_message(e) + if display_message: + raise ExceptionWithDisplayMessage(display_message) from e + else: + raise e + + def to_slice(self) -> Optional[Mapping[str, Any]]: + if self._slice is None: + return None + assert ( + len(self._slice["files"]) == 1 + ), f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}" + file = self._slice["files"][0] + return {"files": [file]} + + def close(self) -> None: + self._cursor.close_partition(self) + self._is_closed = True + + def is_closed(self) -> bool: + return self._is_closed + + def __hash__(self) -> int: + if self._slice: + # Convert the slice to a string so that it can be hashed + if len(self._slice["files"]) != 1: + raise ValueError( + f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. 
Please contact Support." + ) + else: + s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}" + return hash((self._stream.name, s)) + else: + return hash(self._stream.name) + + def stream_name(self) -> str: + return self._stream.name + + def __repr__(self) -> str: + return f"FileBasedStreamPartition({self._stream.name}, {self._slice})" + + +class FileBasedStreamPartitionGenerator(PartitionGenerator): + def __init__( + self, + stream: AbstractFileBasedStream, + message_repository: MessageRepository, + sync_mode: SyncMode, + cursor_field: Optional[List[str]], + state: Optional[MutableMapping[str, Any]], + cursor: FileBasedNoopCursor, + ): + self._stream = stream + self._message_repository = message_repository + self._sync_mode = sync_mode + self._cursor_field = cursor_field + self._state = state + self._cursor = cursor + + def generate(self) -> Iterable[FileBasedStreamPartition]: + pending_partitions = [] + for _slice in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state): + if _slice is not None: + pending_partitions.extend( + [ + FileBasedStreamPartition( + self._stream, + {"files": [copy.deepcopy(f)]}, + self._message_repository, + self._sync_mode, + self._cursor_field, + self._state, + self._cursor, + ) + for f in _slice.get("files", []) + ] + ) + self._cursor.set_pending_partitions(pending_partitions) + yield from pending_partitions diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/cursor.py new file mode 100644 index 000000000000..a0fd47044f3d --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/concurrent/cursor.py @@ -0,0 +1,87 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# +import logging +from abc import abstractmethod +from datetime import datetime +from typing import Any, Iterable, MutableMapping + +from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig +from airbyte_cdk.sources.file_based.remote_file import RemoteFile +from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor +from airbyte_cdk.sources.file_based.types import StreamState +from airbyte_cdk.sources.streams.concurrent.cursor import Cursor +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition +from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + + +class AbstractFileBasedConcurrentCursor(Cursor, AbstractFileBasedCursor): + @property + @abstractmethod + def state(self) -> MutableMapping[str, Any]: + ... + + @abstractmethod + def add_file(self, file: RemoteFile) -> None: + ... + + @abstractmethod + def set_initial_state(self, value: StreamState) -> None: + ... + + @abstractmethod + def get_state(self) -> MutableMapping[str, Any]: + ... + + @abstractmethod + def get_start_time(self) -> datetime: + ... + + @abstractmethod + def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + ... + + @abstractmethod + def observe(self, record: Record) -> None: + ... + + @abstractmethod + def close_partition(self, partition: Partition) -> None: + ... + + @abstractmethod + def set_pending_partitions(self, partitions: Iterable[Partition]) -> None: + ... 
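Before the no-op implementation used for concurrent full refresh, a hedged sketch of how a real cursor might use these hooks: only advance the state watermark once every partition handed out by the partition generator has been closed. The class and its bookkeeping are assumptions for illustration, not the CDK's shipped incremental cursor; it relies only on the interface defined above and on the fact that each file-based partition wraps exactly one file.

class IllustrativePendingPartitionCursor(AbstractFileBasedConcurrentCursor):
    """Hypothetical: advance state only after all pending partitions close."""

    def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any):
        self._pending_uris: set = set()
        self._max_closed = datetime.min
        self._watermark = datetime.min

    @property
    def state(self) -> MutableMapping[str, Any]:
        return {"watermark": self._watermark.isoformat()}

    def add_file(self, file: RemoteFile) -> None:
        return None

    def set_initial_state(self, value: StreamState) -> None:
        return None

    def get_state(self) -> MutableMapping[str, Any]:
        return self.state

    def get_start_time(self) -> datetime:
        return datetime.min

    def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]:
        return all_files

    def observe(self, record: Record) -> None:
        return None

    def set_pending_partitions(self, partitions: Iterable[Partition]) -> None:
        # Remember every partition handed out; each one wraps a single file.
        self._pending_uris = {p.to_slice()["files"][0].uri for p in partitions}

    def close_partition(self, partition: Partition) -> None:
        file = partition.to_slice()["files"][0]
        self._pending_uris.discard(file.uri)
        self._max_closed = max(self._max_closed, file.last_modified)
        if not self._pending_uris:
            # Nothing older is still in flight, so the watermark may advance.
            self._watermark = self._max_closed

Because partitions complete in arbitrary order under the thread pool, gating the watermark on the pending set is what keeps state consistent; the FileBasedNoopCursor below simply skips all of this bookkeeping for full refresh.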
+ + +class FileBasedNoopCursor(AbstractFileBasedConcurrentCursor): + def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any): + pass + + @property + def state(self) -> MutableMapping[str, Any]: + return {} + + def add_file(self, file: RemoteFile) -> None: + return None + + def set_initial_state(self, value: StreamState) -> None: + return None + + def get_state(self) -> MutableMapping[str, Any]: + return {} + + def get_start_time(self) -> datetime: + return datetime.min + + def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + return [] + + def observe(self, record: Record) -> None: + return None + + def close_partition(self, partition: Partition) -> None: + return None + + def set_pending_partitions(self, partitions: Iterable[Partition]) -> None: + return None diff --git a/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py index f6e0ac8e0fe7..0107bd83498e 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py @@ -164,9 +164,13 @@ def get_json_schema(self) -> JsonSchema: try: schema = self._get_raw_json_schema() except (InvalidSchemaError, NoFilesMatchingError) as config_exception: + self.logger.exception(FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exc_info=config_exception) raise AirbyteTracedException( - message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exception=config_exception, failure_type=FailureType.config_error - ) from config_exception + internal_message="Please check the logged errors for more information.", + message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, + exception=AirbyteTracedException(exception=config_exception), + failure_type=FailureType.config_error, + ) except Exception as exc: raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc else: diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py new file mode 100644 index 000000000000..18cacbc500d5 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar + +from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage + +StreamType = TypeVar("StreamType") + + +class AbstractStreamFacade(Generic[StreamType], ABC): + @abstractmethod + def get_underlying_stream(self) -> StreamType: + """ + Return the underlying stream facade object. + """ + ... + + @property + def source_defined_cursor(self) -> bool: + # Streams must be aware of their cursor at instantiation time + return True + + def get_error_display_message(self, exception: BaseException) -> Optional[str]: + """ + Retrieves the user-friendly display message that corresponds to an exception. + This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. + + A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage. 
+ + :param exception: The exception that was raised + :return: A user-friendly message that indicates the cause of the error + """ + if isinstance(exception, ExceptionWithDisplayMessage): + return exception.display_message + else: + return None diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py index f8a5e3ed65e3..ba13a467238d 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -14,7 +14,7 @@ from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy -from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream +from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( AbstractAvailabilityStrategy, StreamAvailability, @@ -24,6 +24,7 @@ from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage +from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -38,7 +39,7 @@ @deprecated("This class is experimental. Use at your own risk.") -class StreamFacade(Stream): +class StreamFacade(AbstractStreamFacade[DefaultStream], Stream): """ The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream. @@ -62,8 +63,8 @@ def create_from_stream( :param max_workers: The maximum number of worker thread to use :return: """ - pk = cls._get_primary_key_from_stream(stream.primary_key) - cursor_field = cls._get_cursor_field_from_stream(stream) + pk = get_primary_key_from_stream(stream.primary_key) + cursor_field = get_cursor_field_from_stream(stream) if not source.message_repository: raise ValueError( @@ -104,33 +105,7 @@ def state(self, value: Mapping[str, Any]) -> None: if "state" in dir(self._legacy_stream): self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above - @classmethod - def _get_primary_key_from_stream(cls, stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]: - if stream_primary_key is None: - return [] - elif isinstance(stream_primary_key, str): - return [stream_primary_key] - elif isinstance(stream_primary_key, list): - if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key): - return stream_primary_key # type: ignore # We verified all items in the list are strings - else: - raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") - else: - raise ValueError(f"Invalid type for primary key: {stream_primary_key}") - - @classmethod - def _get_cursor_field_from_stream(cls, stream: Stream) -> Optional[str]: - if isinstance(stream.cursor_field, list): - if len(stream.cursor_field) > 1: - raise ValueError(f"Nested cursor fields are not supported. 
Got {stream.cursor_field} for {stream.name}") - elif len(stream.cursor_field) == 0: - return None - else: - return stream.cursor_field[0] - else: - return stream.cursor_field - - def __init__(self, stream: AbstractStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger): + def __init__(self, stream: DefaultStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger): """ :param stream: The underlying AbstractStream """ @@ -178,7 +153,7 @@ def read_records( yield from self._read_records() except Exception as exc: if hasattr(self._cursor, "state"): - state = self._cursor.state + state = str(self._cursor.state) else: # This shouldn't happen if the ConcurrentCursor was used state = "unknown; no state attribute was available on the cursor" @@ -210,11 +185,6 @@ def cursor_field(self) -> Union[str, List[str]]: else: return self._abstract_stream.cursor_field - @property - def source_defined_cursor(self) -> bool: - # Streams must be aware of their cursor at instantiation time - return True - @lru_cache(maxsize=None) def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() @@ -233,27 +203,15 @@ def check_availability(self, logger: logging.Logger, source: Optional["Source"] availability = self._abstract_stream.check_availability() return availability.is_available(), availability.message() - def get_error_display_message(self, exception: BaseException) -> Optional[str]: - """ - Retrieves the user-friendly display message that corresponds to an exception. - This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. - - A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage. - - :param exception: The exception that was raised - :return: A user-friendly message that indicates the cause of the error - """ - if isinstance(exception, ExceptionWithDisplayMessage): - return exception.display_message - else: - return None - def as_airbyte_stream(self) -> AirbyteStream: return self._abstract_stream.as_airbyte_stream() def log_stream_sync_configuration(self) -> None: self._abstract_stream.log_stream_sync_configuration() + def get_underlying_stream(self) -> DefaultStream: + return self._abstract_stream + class StreamPartition(Partition): """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/helpers.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/helpers.py new file mode 100644 index 000000000000..ad7722726498 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/helpers.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +from typing import List, Optional, Union + +from airbyte_cdk.sources.streams import Stream + + +def get_primary_key_from_stream(stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]: + if stream_primary_key is None: + return [] + elif isinstance(stream_primary_key, str): + return [stream_primary_key] + elif isinstance(stream_primary_key, list): + if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key): + return stream_primary_key # type: ignore # We verified all items in the list are strings + else: + raise ValueError(f"Nested primary keys are not supported. 
Found {stream_primary_key}") + else: + raise ValueError(f"Invalid type for primary key: {stream_primary_key}") + + +def get_cursor_field_from_stream(stream: Stream) -> Optional[str]: + if isinstance(stream.cursor_field, list): + if len(stream.cursor_field) > 1: + raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}") + elif len(stream.cursor_field) == 0: + return None + else: + return stream.cursor_field[0] + else: + return stream.cursor_field diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/file_types/test_csv_parser.py b/airbyte-cdk/python/unit_tests/sources/file_based/file_types/test_csv_parser.py index 9596cd84c598..3dfea9bb17df 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/file_types/test_csv_parser.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/file_types/test_csv_parser.py @@ -447,11 +447,13 @@ def test_given_generator_closed_when_read_data_then_unregister_dialect(self) -> .build() ) + dialects_before = set(csv.list_dialects()) data_generator = self._read_data() next(data_generator) - assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects() + [new_dialect] = set(csv.list_dialects()) - dialects_before + assert self._CONFIG_NAME in new_dialect data_generator.close() - assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects() + assert new_dialect not in csv.list_dialects() def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None: self._stream_reader.open_file.return_value = ( @@ -466,13 +468,15 @@ def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_a .build() ) + dialects_before = set(csv.list_dialects()) data_generator = self._read_data() next(data_generator) - assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects() + [new_dialect] = set(csv.list_dialects()) - dialects_before + assert self._CONFIG_NAME in new_dialect with pytest.raises(RecordParseError): next(data_generator) - assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects() + assert new_dialect not in csv.list_dialects() def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None: self._stream_reader.open_file.return_value = ( @@ -487,13 +491,15 @@ def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_an .build() ) + dialects_before = set(csv.list_dialects()) data_generator = self._read_data() next(data_generator) - assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects() + [new_dialect] = set(csv.list_dialects()) - dialects_before + assert self._CONFIG_NAME in new_dialect with pytest.raises(RecordParseError): next(data_generator) - assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects() + assert new_dialect not in csv.list_dialects() def _read_data(self) -> Generator[Dict[str, str], None, None]: data_generator = self._csv_reader.read_data( diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/in_memory_files_source.py b/airbyte-cdk/python/unit_tests/sources/file_based/in_memory_files_source.py index 643461471fd5..5db12fc5679c 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/in_memory_files_source.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/in_memory_files_source.py @@ -26,11 +26,14 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, 
AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor, DefaultFileBasedCursor +from airbyte_cdk.sources.source import TState from avro import datafile from pydantic import AnyUrl class InMemoryFilesSource(FileBasedSource): + _concurrency_level = 10 + def __init__( self, files: Mapping[str, Any], @@ -41,6 +44,8 @@ def __init__( parsers: Mapping[str, FileTypeParser], stream_reader: Optional[AbstractFileBasedStreamReader], catalog: Optional[Mapping[str, Any]], + config: Optional[Mapping[str, Any]], + state: Optional[TState], file_write_options: Mapping[str, Any], cursor_cls: Optional[AbstractFileBasedCursor], ): @@ -48,6 +53,9 @@ def __init__( self.files = files self.file_type = file_type self.catalog = catalog + self.configured_catalog = ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None + self.config = config + self.state = state # Source setup stream_reader = stream_reader or InMemoryFilesStreamReader(files=files, file_type=file_type, file_write_options=file_write_options) @@ -55,7 +63,9 @@ def __init__( super().__init__( stream_reader, spec_class=InMemorySpec, - catalog_path="fake_path" if catalog else None, + catalog=self.configured_catalog, + config=self.config, + state=self.state, availability_strategy=availability_strategy, discovery_policy=discovery_policy or DefaultDiscoveryPolicy(), parsers=parsers, @@ -64,7 +74,7 @@ def __init__( ) def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog: - return ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None + return self.configured_catalog class InMemoryFilesStreamReader(AbstractFileBasedStreamReader): diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py index 77164c83d8d8..0f3c0f0ef2a7 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/csv_scenarios.py @@ -847,7 +847,7 @@ "read": [ { "level": "ERROR", - "message": f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream=stream1 file=a.csv line_no=1 n_skipped=0", + "message": f"{FileBasedSourceError.INVALID_SCHEMA_ERROR.value} stream=stream1 file=a.csv line_no=1 n_skipped=0", }, ] } @@ -1471,28 +1471,7 @@ } ) .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) - .set_expected_records( - [ - { - "data": { - "col1": "val11", - "col2": "val12", - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "a.csv", - }, - "stream": "stream1", - }, - { - "data": { - "col1": "val21", - "col2": "val22", - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "a.csv", - }, - "stream": "stream1", - }, - ] - ) + .set_expected_records([]) ).build() schemaless_csv_scenario: TestScenario[InMemoryFilesSource] = ( @@ -1766,7 +1745,7 @@ } ) .set_expected_check_status("FAILED") - .set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) + .set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) .set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) ).build() @@ -1854,7 +1833,7 @@ } ) 
.set_expected_check_status("FAILED") - .set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) + .set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) .set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) ).build() @@ -3029,6 +3008,7 @@ .set_file_type("csv") ) .set_expected_check_status("FAILED") + .set_expected_check_error(AirbyteTracedException, FileBasedSourceError.EMPTY_STREAM.value) .set_expected_catalog( { "streams": [ diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py index 90deb31fe41b..f3d72ab67e7a 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/file_based_source_builder.py @@ -14,6 +14,7 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor +from airbyte_cdk.sources.source import TState from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder @@ -29,8 +30,10 @@ def __init__(self) -> None: self._stream_reader: Optional[AbstractFileBasedStreamReader] = None self._file_write_options: Mapping[str, Any] = {} self._cursor_cls: Optional[Type[AbstractFileBasedCursor]] = None + self._config: Optional[Mapping[str, Any]] = None + self._state: Optional[TState] = None - def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> InMemoryFilesSource: + def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> InMemoryFilesSource: if self._file_type is None: raise ValueError("file_type is not set") return InMemoryFilesSource( @@ -42,6 +45,8 @@ def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> InMemoryFile self._parsers, self._stream_reader, configured_catalog, + config, + state, self._file_write_options, self._cursor_cls, ) diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py index b4a447c4f0c0..2998f3deb6cc 100644 --- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py @@ -485,14 +485,10 @@ } ) .set_expected_records( - [ - { - "data": {"col1": "val1", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.jsonl"}, - "stream": "stream1", - }, - ] + [] ) .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) + .set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.") .set_expected_logs( { "read": [ diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py index 0852de4a361a..30ffa263f88e 100644 --- 
a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py
+++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/parquet_scenarios.py
@@ -729,6 +729,7 @@
     .set_expected_records([])
     .set_expected_logs({"read": [{"level": "ERROR", "message": "Error parsing record"}]})
     .set_expected_discover_error(AirbyteTracedException, "Error inferring schema from files")
+    .set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.")
     .set_expected_catalog(
         {
             "streams": [
diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py
index 75feaf360595..7e48af119f8b 100644
--- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py
+++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/scenario_builder.py
@@ -8,6 +8,7 @@
 
 from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode
 from airbyte_cdk.sources import AbstractSource
+from airbyte_cdk.sources.source import TState
 from airbyte_protocol.models import ConfiguredAirbyteCatalog
 
 
@@ -26,7 +27,7 @@ class SourceBuilder(ABC, Generic[SourceType]):
     """
 
     @abstractmethod
-    def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> SourceType:
+    def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> SourceType:
         raise NotImplementedError()
 
 
@@ -80,11 +81,11 @@ def configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]
             return self.catalog.dict()  # type: ignore  # dict() is not typed
 
         catalog: Mapping[str, Any] = {"streams": []}
-        for stream in self.source.streams(self.config):
+        for stream in self.config["streams"]:
             catalog["streams"].append(
                 {
                     "stream": {
-                        "name": stream.name,
+                        "name": stream["name"],
                         "json_schema": {},
                         "supported_sync_modes": [sync_mode.value],
                     },
@@ -152,7 +153,7 @@ def set_expected_logs(self, expected_logs: Mapping[str, List[Mapping[str, Any]]]
         self._expected_logs = expected_logs
         return self
 
-    def set_expected_records(self, expected_records: List[Mapping[str, Any]]) -> "TestScenarioBuilder[SourceType]":
+    def set_expected_records(self, expected_records: Optional[List[Mapping[str, Any]]]) -> "TestScenarioBuilder[SourceType]":
         self._expected_records = expected_records
         return self
 
@@ -191,7 +192,9 @@ def build(self) -> "TestScenario[SourceType]":
         if self.source_builder is None:
             raise ValueError("source_builder is not set")
         source = self.source_builder.build(
-            self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh)
+            self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh),
+            self._config,
+            self._incremental_scenario_config.input_state if self._incremental_scenario_config else None,
         )
         return TestScenario(
             self._name,
diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py
index 9ac880b11fe5..4ff096954523 100644
--- a/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py
+++ b/airbyte-cdk/python/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py
@@ -661,29 +661,7 @@
             ]
         }
     )
-    .set_expected_records(
-        [
-            {
-                "data": {
-                    "col1": "val_a_11",
-                    "col2": 1,
-                    "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.csv", - }, - "stream": "stream1", - }, - { - "data": { - "col1": "val_a_12", - "col2": 2, - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "a.csv", - }, - "stream": "stream1", - }, - # No records past that because the first record for the second file did not conform to the schema - ] - ) + .set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops .set_expected_logs( { "read": [ @@ -722,56 +700,7 @@ ] } ) - .set_expected_records( - [ - { - "data": { - "col1": "val_aa1_11", - "col2": 1, - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "a/a1.csv", - }, - "stream": "stream1", - }, - { - "data": { - "col1": "val_aa1_12", - "col2": 2, - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "a/a1.csv", - }, - "stream": "stream1", - }, - # {"data": {"col1": "val_aa2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"}, - # {"data": {"col1": "val_aa2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"}, - # {"data": {"col1": "val_aa3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"}, - # {"data": {"col1": "val_aa3_12", None: "val_aa3_22", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"}, - # {"data": {"col1": "val_aa3_13", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"}, - # {"data": {"col1": "val_aa4_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a4.csv"}, "stream": "stream1"}, - { - "data": { - "col1": "val_bb1_11", - "col2": 1, - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "b/b1.csv", - }, - "stream": "stream2", - }, - { - "data": { - "col1": "val_bb1_12", - "col2": 2, - "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", - "_ab_source_file_url": "b/b1.csv", - }, - "stream": "stream2", - }, - # {"data": {"col1": "val_bb2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"}, - # {"data": {"col1": "val_bb2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"}, - # {"data": {"col1": "val_bb3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"}, - # {"data": {"col1": "val_bb3_12", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"}, - ] - ) + .set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops .set_expected_logs( { "read": [ diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/stream/concurrent/test_adapters.py b/airbyte-cdk/python/unit_tests/sources/file_based/stream/concurrent/test_adapters.py new file mode 100644 index 000000000000..e63f950b5b2a --- /dev/null +++ 
b/airbyte-cdk/python/unit_tests/sources/file_based/stream/concurrent/test_adapters.py
new file mode 100644
--- /dev/null
+++ b/airbyte-cdk/python/unit_tests/sources/file_based/stream/concurrent/test_adapters.py
@@ -0,0 +1,365 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+import logging
+import unittest
+from datetime import datetime
+from unittest.mock import MagicMock, Mock
+
+import pytest
+from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
+from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.file_based.availability_strategy import DefaultFileBasedAvailabilityStrategy
+from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
+from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
+from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy
+from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector
+from airbyte_cdk.sources.file_based.file_types import default_parsers
+from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+from airbyte_cdk.sources.file_based.schema_validation_policies import EmitRecordPolicy
+from airbyte_cdk.sources.file_based.stream import DefaultFileBasedStream
+from airbyte_cdk.sources.file_based.stream.concurrent.adapters import (
+    FileBasedStreamFacade,
+    FileBasedStreamPartition,
+    FileBasedStreamPartitionGenerator,
+)
+from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
+from airbyte_cdk.sources.message import InMemoryMessageRepository
+from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
+from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
+from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
+from airbyte_cdk.sources.utils.slice_logger import SliceLogger
+from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
+from freezegun import freeze_time
+
+_ANY_SYNC_MODE = SyncMode.full_refresh
+_ANY_STATE = {"state_key": "state_value"}
+_ANY_CURSOR_FIELD = ["a", "cursor", "key"]
+_STREAM_NAME = "stream"
+_ANY_CURSOR = Mock(spec=FileBasedNoopCursor)
+
+
+@pytest.mark.parametrize(
+    "sync_mode",
+    [
+        pytest.param(SyncMode.full_refresh, id="test_full_refresh"),
+        pytest.param(SyncMode.incremental, id="test_incremental"),
+    ],
+)
+def test_file_based_stream_partition_generator(sync_mode):
+    stream = Mock()
+    message_repository = Mock()
+    stream_slices = [
+        {"files": [RemoteFile(uri="1", last_modified=datetime.now())]},
+        {"files": [RemoteFile(uri="2", last_modified=datetime.now())]},
+    ]
+    stream.stream_slices.return_value = stream_slices
+
+    partition_generator = FileBasedStreamPartitionGenerator(stream, message_repository, sync_mode, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
+
+    partitions = list(partition_generator.generate())
+    slices = [partition.to_slice() for partition in partitions]
+    assert slices == stream_slices
+    stream.stream_slices.assert_called_once_with(sync_mode=sync_mode, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE)
+
+
+@pytest.mark.parametrize(
+    "transformer, expected_records",
+    [
+        pytest.param(
+            TypeTransformer(TransformConfig.NoTransform),
+            [Record({"data": "1"}, _STREAM_NAME), Record({"data": "2"}, _STREAM_NAME)],
+            id="test_no_transform",
+        ),
+        pytest.param(
+            TypeTransformer(TransformConfig.DefaultSchemaNormalization),
+            [Record({"data": 1}, _STREAM_NAME), Record({"data": 2}, _STREAM_NAME)],
+            id="test_default_transform",
+        ),
+    ],
+)
+def test_file_based_stream_partition(transformer, expected_records):
+    stream = Mock()
+    stream.name = _STREAM_NAME
+    stream.get_json_schema.return_value = {"type": "object", "properties": {"data": {"type": ["integer"]}}}
+    stream.transformer = transformer
+    message_repository = InMemoryMessageRepository()
+    _slice = None
+    sync_mode = SyncMode.full_refresh
+    cursor_field = None
+    state = None
+    partition = FileBasedStreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR)
+
+    a_log_message = AirbyteMessage(
+        type=MessageType.LOG,
+        log=AirbyteLogMessage(
+            level=Level.INFO,
+            message='slice:{"partition": 1}',
+        ),
+    )
+
+    stream_data = [a_log_message, {"data": "1"}, {"data": "2"}]
+    stream.read_records.return_value = stream_data
+
+    records = list(partition.read())
+    messages = list(message_repository.consume_queue())
+
+    assert records == expected_records
+    assert messages == [a_log_message]
+
+
+@pytest.mark.parametrize(
+    "exception_type, expected_display_message",
+    [
+        pytest.param(Exception, None, id="test_exception_no_display_message"),
+        pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_with_display_message"),
+    ],
+)
+def test_file_based_stream_partition_raising_exception(exception_type, expected_display_message):
+    stream = Mock()
+    stream.get_error_display_message.return_value = expected_display_message
+
+    message_repository = InMemoryMessageRepository()
+    _slice = None
+
+    partition = FileBasedStreamPartition(stream, _slice, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
+
+    stream.read_records.side_effect = Exception()
+
+    with pytest.raises(exception_type) as e:
+        list(partition.read())
+    if isinstance(e.value, ExceptionWithDisplayMessage):
+        assert e.value.display_message == expected_display_message
+
+
+@freeze_time("2023-06-09T00:00:00Z")
+@pytest.mark.parametrize(
+    "_slice, expected_hash",
+    [
+        pytest.param({"files": [RemoteFile(uri="1", last_modified=datetime.strptime("2023-06-09T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ"))]}, hash(("stream", "2023-06-09T00:00:00.000000Z_1")), id="test_hash_with_slice"),
+        pytest.param(None, hash("stream"), id="test_hash_no_slice"),
+    ],
+)
+def test_file_based_stream_partition_hash(_slice, expected_hash):
+    stream = Mock()
+    stream.name = "stream"
+    partition = FileBasedStreamPartition(stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
+
+    _hash = hash(partition)
+    assert _hash == expected_hash
+
+
+class StreamFacadeTest(unittest.TestCase):
+    def setUp(self):
+        self._abstract_stream = Mock()
+        self._abstract_stream.name = "stream"
+        self._abstract_stream.as_airbyte_stream.return_value = AirbyteStream(
+            name="stream",
+            json_schema={"type": "object"},
+            supported_sync_modes=[SyncMode.full_refresh],
+        )
+        self._legacy_stream = DefaultFileBasedStream(
+            cursor=FileBasedNoopCursor(MagicMock()),
+            config=FileBasedStreamConfig(name="stream", format=CsvFormat()),
+            catalog_schema={},
+            stream_reader=MagicMock(),
+            availability_strategy=DefaultFileBasedAvailabilityStrategy(MagicMock()),
+            discovery_policy=DefaultDiscoveryPolicy(),
+            parsers=default_parsers,
+            validation_policy=EmitRecordPolicy(),
+            errors_collector=FileBasedErrorsCollector(),
+        )
+        self._cursor = Mock(spec=Cursor)
+        self._logger = Mock()
+        self._slice_logger = Mock()
+        self._slice_logger.should_log_slice_message.return_value = False
+        self._facade = FileBasedStreamFacade(self._abstract_stream, self._legacy_stream, self._cursor, self._slice_logger, self._logger)
+        self._source = Mock()
+
+        self._stream = Mock()
+        self._stream.primary_key = "id"
+
+    def test_name_is_delegated_to_wrapped_stream(self):
+        assert self._facade.name == self._abstract_stream.name
+
+    def test_cursor_field_is_a_string(self):
+        self._abstract_stream.cursor_field = "cursor_field"
+        assert self._facade.cursor_field == "cursor_field"
+
+    def test_source_defined_cursor_is_true(self):
+        assert self._facade.source_defined_cursor
+
+    def test_json_schema_is_delegated_to_wrapped_stream(self):
+        json_schema = {"type": "object"}
+        self._abstract_stream.get_json_schema.return_value = json_schema
+        assert self._facade.get_json_schema() == json_schema
+        self._abstract_stream.get_json_schema.assert_called_once_with()
+
+    def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response(self):
+        assert (
+            FileBasedStreamFacade(
+                self._abstract_stream, self._legacy_stream, _ANY_CURSOR, Mock(spec=SliceLogger), Mock(spec=logging.Logger)
+            ).supports_incremental
+            == self._legacy_stream.supports_incremental
+        )
+
+    def test_given_cursor_is_not_noop_when_supports_incremental_then_return_true(self):
+        assert FileBasedStreamFacade(
+            self._abstract_stream, self._legacy_stream, Mock(spec=Cursor), Mock(spec=SliceLogger), Mock(spec=logging.Logger)
+        ).supports_incremental
+
+    def test_full_refresh(self):
+        expected_stream_data = [{"data": 1}, {"data": 2}]
+        records = [Record(data, "stream") for data in expected_stream_data]
+
+        partition = Mock()
+        partition.read.return_value = records
+        self._abstract_stream.generate_partitions.return_value = [partition]
+
+        actual_stream_data = list(self._facade.read_records(SyncMode.full_refresh, None, {}, None))
+
+        assert actual_stream_data == expected_stream_data
+
+    def test_read_records_full_refresh(self):
+        expected_stream_data = [{"data": 1}, {"data": 2}]
+        records = [Record(data, "stream") for data in expected_stream_data]
+        partition = Mock()
+        partition.read.return_value = records
+        self._abstract_stream.generate_partitions.return_value = [partition]
+
+        actual_stream_data = list(self._facade.read_full_refresh(None, None, None))
+
+        assert actual_stream_data == expected_stream_data
+
+    def test_read_records_incremental(self):
+        expected_stream_data = [{"data": 1}, {"data": 2}]
+        records = [Record(data, "stream") for data in expected_stream_data]
+        partition = Mock()
+        partition.read.return_value = records
+        self._abstract_stream.generate_partitions.return_value = [partition]
+
+        actual_stream_data = list(self._facade.read_incremental(None, None, None, None, None, None, None))
+
+        assert actual_stream_data == expected_stream_data
+
+    def test_create_from_stream_stream(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = "id"
+        stream.cursor_field = "cursor"
+
+        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+        assert facade.name == "stream"
+        assert facade.cursor_field == "cursor"
+        assert facade._abstract_stream._primary_key == ["id"]
+
+    def test_create_from_stream_stream_with_none_primary_key(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = None
+        stream.cursor_field = []
+
+        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+        assert facade._abstract_stream._primary_key == []
+
+    def test_create_from_stream_with_composite_primary_key(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = ["id", "name"]
+        stream.cursor_field = []
+
+        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+        assert facade._abstract_stream._primary_key == ["id", "name"]
+
+    def test_create_from_stream_with_empty_list_cursor(self):
+        stream = Mock()
+        stream.primary_key = "id"
+        stream.cursor_field = []
+
+        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+        assert facade.cursor_field == []
+
+    def test_create_from_stream_raises_exception_if_primary_key_is_nested(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = [["field", "id"]]
+
+        with self.assertRaises(ValueError):
+            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+    def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = 123
+
+        with self.assertRaises(ValueError):
+            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+    def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = "id"
+        stream.cursor_field = ["field", "cursor"]
+
+        with self.assertRaises(ValueError):
+            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+    def test_create_from_stream_with_cursor_field_as_list(self):
+        stream = Mock()
+        stream.name = "stream"
+        stream.primary_key = "id"
+        stream.cursor_field = ["cursor"]
+
+        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
+        assert facade.cursor_field == "cursor"
+
+    def test_create_from_stream_none_message_repository(self):
+        self._stream.name = "stream"
+        self._stream.primary_key = "id"
+        self._stream.cursor_field = "cursor"
+        self._source.message_repository = None
+
+        with self.assertRaises(ValueError):
+            FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, {}, self._cursor)
+
+    def test_get_error_display_message_no_display_message(self):
+        self._stream.get_error_display_message.return_value = "display_message"
+
+        facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+        expected_display_message = None
+        e = Exception()
+
+        display_message = facade.get_error_display_message(e)
+
+        assert expected_display_message == display_message
+
+    def test_get_error_display_message_with_display_message(self):
+        self._stream.get_error_display_message.return_value = "display_message"
+
+        facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)
+
+        expected_display_message = "display_message"
+        e = ExceptionWithDisplayMessage("display_message")
+
+        display_message = facade.get_error_display_message(e)
+
+        assert expected_display_message == display_message
+
+
+@pytest.mark.parametrize(
+    "exception, expected_display_message",
+    [
+        pytest.param(Exception("message"), None, id="test_no_display_message"),
+        pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_with_display_message"),
+    ],
+)
+def test_get_error_display_message(exception, expected_display_message):
+    stream = Mock()
+    legacy_stream = Mock()
+    cursor = Mock(spec=Cursor)
+    facade = FileBasedStreamFacade(stream, legacy_stream, cursor, Mock(), Mock())
+
+    display_message = facade.get_error_display_message(exception)
+
+    assert display_message == expected_display_message
diff --git a/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py b/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py
index 747d22a31a1f..7bd256404205 100644
--- a/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py
+++ b/airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py
@@ -73,8 +73,21 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
     records, log_messages = output.records_and_state_messages, output.logs
     logs = [message.log for message in log_messages if message.log.level.value in scenario.log_levels]
     expected_records = scenario.expected_records
+
+    if expected_records is None:
+        return
+
     assert len(records) == len(expected_records)
-    for actual, expected in zip(records, expected_records):
+
+    sorted_expected_records = sorted(
+        filter(lambda e: "data" in e, expected_records),
+        key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x["data"].items()) if k != "emitted_at"),
+    )
+    sorted_records = sorted(
+        filter(lambda r: r.record, records),
+        key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x.record.data.items()) if k != "emitted_at"),
+    )
+    for actual, expected in zip(sorted_records, sorted_expected_records):
         if actual.record:
             assert len(actual.record.data) == len(expected["data"])
             for key, value in actual.record.data.items():
@@ -83,8 +96,11 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
             else:
                 assert value == expected["data"][key]
             assert actual.record.stream == expected["stream"]
-        elif actual.state:
-            assert actual.state.data == expected
+
+    expected_states = filter(lambda e: "data" not in e, expected_records)
+    states = filter(lambda r: r.state, records)
+    for actual, expected in zip(states, expected_states):  # unlike records, states are emitted in a deterministic order
+        assert actual.state.data == expected
 
     if scenario.expected_logs:
         read_logs = scenario.expected_logs.get("read")
@@ -129,7 +145,7 @@ def verify_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: Tes
     output = check(capsys, tmp_path, scenario)
     if expected_msg:
         # expected_msg is a plain string; match it against the check output message directly
-        assert expected_msg.value in output["message"]  # type: ignore
+        assert expected_msg in output["message"]  # type: ignore
 
         assert output["status"] == scenario.expected_check_status
     else:
diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py
index 30ec297b0b4f..ee6e27e9ccbf 100644
--- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py
+++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py
@@ -11,6 +11,7 @@
 from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
+from airbyte_cdk.sources.source import TState
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
 from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor
@@ -123,6 +124,6 @@ def set_input_state(self, state: List[Mapping[str, Any]]) -> "StreamFacadeSource
         self._input_state = state
         return self
 
-    def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> StreamFacadeSource:
+    def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> StreamFacadeSource:
         threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
-        return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, self._input_state)
+        return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, state)
diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py
index 943aea30dbba..87a65ea6efd8 100644
--- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py
+++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py
@@ -119,7 +119,7 @@ def __init__(self):
         self._streams: List[DefaultStream] = []
         self._message_repository = None
 
-    def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> ConcurrentCdkSource:
+    def build(self, configured_catalog: Optional[Mapping[str, Any]], _, __) -> ConcurrentCdkSource:
         return ConcurrentCdkSource(self._streams, self._message_repository, 1, 1)
 
     def set_streams(self, streams: List[DefaultStream]) -> "ConcurrentSourceBuilder":
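
Note on the _verify_read_output change above: because concurrent partition reads can interleave records from different files, the test now compares records as an unordered collection, sorting both the actual and the expected side on a canonical key derived from the record data (ignoring the run-dependent "emitted_at" field), while state messages are still compared in emission order. Below is a minimal, self-contained sketch of that comparison strategy; the helper names (_sort_key, records_match) are illustrative and not part of the CDK.

from typing import Any, Iterable, Mapping


def _sort_key(data: Mapping[str, Any]) -> str:
    # Canonical, order-insensitive key over the record payload;
    # "emitted_at" is excluded because it varies between runs.
    return ",".join(f"{k}={v}" for k, v in sorted(data.items()) if k != "emitted_at")


def records_match(actual: Iterable[Mapping[str, Any]], expected: Iterable[Mapping[str, Any]]) -> bool:
    # Concurrent partitions interleave their output, so equality is checked
    # after sorting both sides on the same key rather than positionally.
    return sorted(actual, key=_sort_key) == sorted(expected, key=_sort_key)


# Records arriving in a different order still compare equal:
assert records_match([{"b": 2}, {"a": 1}], [{"a": 1}, {"b": 2}])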