Skip to content

Commit

Permalink
use a facade and a legacy adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
girarda committed Sep 19, 2023
1 parent 3f54c36 commit cff94f3
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 20 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def read(
# TODO assert all streams exist in the connector
# get the streams once in case the connector needs to make any queries to generate them
stream_instances = {s.name: s for s in self.streams(config)}
logger.info(f"Streams detected by {self.name}: {list(stream_instances.keys())}")
state_manager = ConnectorStateManager(stream_instance_map=stream_instances, state=state)
self._stream_to_instance_map = stream_instances
with create_timer(self.name) as timer:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import ABC, abstractmethod
from typing import Iterable

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.utils import casing
from airbyte_cdk.sources.utils.types import StreamData


Expand All @@ -22,3 +25,24 @@ def read(
:param internal_config:
:return: The stream's records
"""

@property
def name(self) -> str:
    """
    :return: The stream's name. Unless a subclass overrides this property, it is the
        snake_case form of the implementing class's name.
    """
    class_name = self.__class__.__name__
    return casing.camel_to_snake(class_name)

@property
def logger(self) -> logging.Logger:
    """
    :return: The logger registered under "airbyte.streams.<stream name>".
    """
    logger_name = f"airbyte.streams.{self.name}"
    return logging.getLogger(logger_name)

@staticmethod
# FIXME: need to move this!
def is_record(record_data_or_message: StreamData) -> bool:
    """
    Tell whether a value produced by a stream is a record.

    :param record_data_or_message: Either a plain dict or an AirbyteMessage
    :return: True for any dict and for an AirbyteMessage whose type is RECORD; False for everything else
    """
    # A plain dict is always treated as record data.
    if isinstance(record_data_or_message, dict):
        return True
    # Anything that is neither a dict nor an AirbyteMessage is not a record.
    if not isinstance(record_data_or_message, AirbyteMessage):
        return False
    return bool(record_data_or_message.type == MessageType.RECORD)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from airbyte_cdk.sources.streams.partitions.legacy import LegacyPartitionGenerator
from airbyte_cdk.sources.streams.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.types import StreamData


Expand Down Expand Up @@ -54,9 +53,6 @@ def create_from_legacy_stream(cls, stream: Stream, max_workers: int) -> "Concurr
max_workers=max_workers,
name=stream.name,
json_schema=stream.get_json_schema(),
availability_strategy=AvailabilityStrategyLegacyAdapter(stream, stream.availability_strategy)
if stream.availability_strategy
else None,
primary_key=stream.primary_key,
cursor_field=stream.cursor_field,
)
Expand All @@ -67,7 +63,6 @@ def __init__(
max_workers: int,
name: str,
json_schema: Mapping[str, Any],
availability_strategy: Optional[AvailabilityStrategy],
primary_key: Optional[Union[str, List[str], List[List[str]]]],
cursor_field: Union[str, List[str]],
):
Expand All @@ -76,18 +71,13 @@ def __init__(
self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
self._name = name
self._json_schema = json_schema
self._availability_strategy = availability_strategy
self._primary_key = primary_key
self._cursor_field = cursor_field

def read(
self,
cursor_field: Optional[List[str]],
logger: logging.Logger,
slice_logger: SliceLogger,
internal_config: InternalConfig = InternalConfig(),
) -> Iterable[StreamData]:
logger.debug(f"Processing stream slices for {self.name} (sync_mode: full_refresh)")
def read(self) -> Iterable[StreamData]:
# FIXME
internal_config = InternalConfig()
self.logger.debug(f"Processing stream slices for {self.name} (sync_mode: full_refresh)")
total_records_counter = 0
futures = []
partition_generator = ConcurrentPartitionGenerator()
Expand Down Expand Up @@ -115,10 +105,6 @@ def read(
partition = partition_generator.get_next()
if partition is not None:
futures.append(self._threadpool.submit(partition_reader.process_partition, partition))
if slice_logger.should_log_slice_message(logger):
# FIXME: This is creating slice log messages for parity with the synchronous implementation
# but these cannot be used by the connector builder to build slices because they can be unordered
yield slice_logger.create_slice_log_message(partition.to_slice())
self._check_for_errors(futures)

def _is_done(self, futures: List[Future[Any]]) -> bool:
Expand Down
45 changes: 45 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/stream_facade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from typing import Any, Iterable, List, Mapping, Optional, Union

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.abstract_stream import AbstractStream
from airbyte_cdk.sources.utils.types import StreamData


class StreamFacade(Stream):
    """
    Expose an AbstractStream through the legacy Stream interface.

    Every member delegates to the wrapped stream. The metadata accessors
    (primary key, cursor field, JSON schema, source-defined cursor) fall back to a
    value of the annotated return type when the wrapped stream does not expose the
    corresponding member, so they never return None where the signature promises
    otherwise and never return the wrapped stream itself.

    NOTE(review): the fallbacks assume the wrapped stream may not yet expose these
    members under the same names as Stream -- confirm against AbstractStream.
    """

    def __init__(self, stream: AbstractStream):
        """
        :param stream: The stream that all calls are delegated to.
        """
        self._stream = stream

    @property
    def name(self) -> str:
        """
        :return: The name of the wrapped stream.
        """
        return self._stream.name

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        """
        Read the records of the wrapped stream.

        All arguments are accepted for compatibility with the Stream interface but are
        not forwarded: the wrapped stream's read() is called without arguments.

        :param sync_mode: unused
        :param cursor_field: unused
        :param stream_slice: unused
        :param stream_state: unused
        :return: Whatever the wrapped stream's read() returns
        """
        return self._stream.read()

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        """
        :return: The wrapped stream's primary key, or None if it does not expose one.
        """
        # Previously returned the wrapped stream object itself, which is not a valid key.
        return getattr(self._stream, "primary_key", None)

    @property
    def cursor_field(self) -> Union[str, List[str]]:
        """
        :return: The wrapped stream's cursor field, or an empty list if it does not expose one.
        """
        return getattr(self._stream, "cursor_field", [])

    def get_json_schema(self) -> Mapping[str, Any]:
        """
        :return: The wrapped stream's JSON schema, or an empty mapping if it does not provide one.
        """
        schema_getter = getattr(self._stream, "get_json_schema", None)
        if callable(schema_getter):
            return schema_getter()
        return {}

    @property
    def source_defined_cursor(self) -> bool:
        """
        :return: The wrapped stream's source_defined_cursor flag, or True if it does not expose one.
        """
        # NOTE(review): True is assumed to match the Stream base-class default -- confirm.
        return getattr(self._stream, "source_defined_cursor", True)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.concurrent_stream import ConcurrentStream
from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator
from airbyte_cdk.sources.streams.partitions.legacy import LegacyPartitionGenerator
from airbyte_cdk.sources.streams.stream_facade import StreamFacade
from airbyte_cdk.utils import AirbyteTracedException
from source_stripe.streams import (
CheckoutSessionsLineItems,
Expand All @@ -30,6 +33,27 @@
)


class ConcurrentStreamAdapter(Stream):
    """Builds a StreamFacade around a ConcurrentStream configured from a legacy Stream."""

    @classmethod
    def create_from_legacy_stream(cls, stream: Stream, max_workers: int) -> Stream:
        """
        Wrap a legacy Stream in a ConcurrentStream and expose it through a StreamFacade.

        :param stream: Legacy stream that supplies the name, JSON schema, primary key and
            cursor field, and that the LegacyPartitionGenerator is built from
        :param max_workers: Passed to the ConcurrentStream as max_workers
        :return: A StreamFacade wrapping the newly created ConcurrentStream
        """
        concurrent_stream = ConcurrentStream(
            partition_generator=LegacyPartitionGenerator(stream),
            max_workers=max_workers,
            name=stream.name,
            json_schema=stream.get_json_schema(),
            primary_key=stream.primary_key,
            cursor_field=stream.cursor_field,
        )
        return StreamFacade(concurrent_stream)


class SourceStripe(AbstractSource):
@staticmethod
def validate_and_fill_with_defaults(config: MutableMapping) -> MutableMapping:
Expand Down Expand Up @@ -159,7 +183,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
],
**args,
)
return [
legacy_streams = [
CheckoutSessionsLineItems(**incremental_args),
CustomerBalanceTransactions(**args),
Events(**incremental_args),
Expand Down Expand Up @@ -415,3 +439,5 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
**args,
),
]
# return legacy_streams
return [ConcurrentStreamAdapter.create_from_legacy_stream(stream, 6) for stream in legacy_streams]

0 comments on commit cff94f3

Please sign in to comment.