Skip to content

Commit

Permalink
File-based CDK: make full refresh concurrent
Browse files Browse the repository at this point in the history
  • Loading branch information
clnoll committed Jan 21, 2024
1 parent e3e58cc commit f2c6da2
Show file tree
Hide file tree
Showing 21 changed files with 526 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple, Union

from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream

if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade


class AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy):
Expand All @@ -26,7 +27,7 @@ def check_availability(self, stream: Stream, logger: logging.Logger, _: Optional

@abstractmethod
def check_availability_and_parsability(
self, stream: "AbstractFileBasedStream", logger: logging.Logger, _: Optional[Source]
self, stream: Union["AbstractFileBasedStream", "FileBasedStreamFacade"], logger: logging.Logger, _: Optional[Source]
) -> Tuple[bool, Optional[str]]:
"""
Performs a connection check for the stream, as well as additional checks that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
from collections import Counter
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.logger import AirbyteLogFormatter, init_logger
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConnectorSpecification,
FailureType,
Level,
SyncMode,
)
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy
Expand All @@ -20,19 +30,32 @@
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic.error_wrappers import ValidationError

DEFAULT_CONCURRENCY = 100
MAX_CONCURRENCY = 100
INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2


class FileBasedSource(ConcurrentSourceAdapter, ABC):
concurrency_level = MAX_CONCURRENCY

class FileBasedSource(AbstractSource, ABC):
def __init__(
self,
stream_reader: AbstractFileBasedStreamReader,
spec_class: Type[AbstractFileBasedSpec],
catalog_path: Optional[str] = None,
catalog: Optional[ConfiguredAirbyteCatalog],
config: Optional[Mapping[str, Any]],
state: Optional[TState],
availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
Expand All @@ -41,15 +64,29 @@ def __init__(
):
self.stream_reader = stream_reader
self.spec_class = spec_class
self.config = config
self.catalog = catalog
self.state = state
self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
self.discovery_policy = discovery_policy
self.parsers = parsers
self.validation_policies = validation_policies
catalog = self.read_catalog(catalog_path) if catalog_path else None
self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
self.cursor_cls = cursor_cls
self.logger = logging.getLogger(f"airbyte.{self.name}")
self.logger = init_logger(f"airbyte.{self.name}")
self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector()
self._message_repository: Optional[MessageRepository] = None
concurrent_source = ConcurrentSource.create(
MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository
)
self._state = None
super().__init__(concurrent_source)

@property
def message_repository(self) -> MessageRepository:
if self._message_repository is None:
self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level]))
return self._message_repository

def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
"""
Expand All @@ -61,7 +98,15 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->
Otherwise, the "error" object should describe what went wrong.
"""
streams = self.streams(config)
try:
streams = self.streams(config)
except Exception as config_exception:
raise AirbyteTracedException(
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
failure_type=FailureType.config_error,
)
if len(streams) == 0:
return (
False,
Expand All @@ -72,7 +117,7 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->

errors = []
for stream in streams:
if not isinstance(stream, AbstractFileBasedStream):
if not isinstance(stream, (AbstractFileBasedStream, FileBasedStreamFacade)):
raise ValueError(f"Stream {stream} is not a file-based stream.")
try:
(
Expand All @@ -91,10 +136,24 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
"""
Return a list of this source's streams.
"""
file_based_streams = self._get_file_based_streams(config)

configured_streams: List[Stream] = []

for stream in file_based_streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh:
configured_streams.append(FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor()))
else:
configured_streams.append(stream)

return configured_streams

def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
try:
parsed_config = self._get_parsed_config(config)
self.stream_reader.config = parsed_config
streams: List[Stream] = []
streams: List[AbstractFileBasedStream] = []
for stream_config in parsed_config.streams:
self._validate_input_schema(stream_config)
streams.append(
Expand All @@ -115,6 +174,13 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
except ValidationError as exc:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc

def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
if self.catalog:
for catalog_stream in self.catalog.streams:
if stream.name == catalog_stream.stream.name:
return catalog_stream.sync_mode
return None

def read(
self,
logger: logging.Logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import partial
from io import IOBase
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple
from uuid import uuid4

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType
Expand Down Expand Up @@ -38,8 +39,10 @@ def read_data(

# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
# We don't unregister the dialect because we are lazily parsing each csv file to generate records
# This will potentially be a problem if we ever process multiple streams concurrently
dialect_name = config.name + DIALECT_NAME
# Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up
# with a race condition where a thread attempts to use a dialect before a separate thread has finished
# registering it.
dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}"
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
Expand Down
Empty file.
Loading

0 comments on commit f2c6da2

Please sign in to comment.