Skip to content

Commit

Permalink
Add IncrementalMixin to IncrementalNotionStream for per-record state tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristoGrab committed May 30, 2024
1 parent 6db379d commit 8bbe162
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import requests
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams import IncrementalMixin, Stream
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy
from airbyte_cdk.sources.streams.http.exceptions import UserDefinedBackoffException
Expand Down Expand Up @@ -165,7 +165,7 @@ def dict(self, **kwargs):
return {pydantic.utils.ROOT_KEY: self.value}


class IncrementalNotionStream(NotionStream, ABC):
class IncrementalNotionStream(NotionStream, IncrementalMixin, ABC):

cursor_field = "last_edited_time"

Expand All @@ -180,6 +180,14 @@ def __init__(self, obj_type: Optional[str] = None, **kwargs):
# object type for search filtering, either "page" or "database" if not None
self.obj_type = obj_type

@property
def state(self) -> MutableMapping[str, Any]:
    """Return the stream's current cursor state.

    Falls back to an empty mapping when ``_state`` has not been set yet:
    the visible ``__init__`` never initializes ``_state``, so the framework
    reading ``stream.state`` before ``read_records`` runs would otherwise
    raise ``AttributeError``.  NOTE(review): confirm no other code path
    initializes ``_state`` earlier.
    """
    return getattr(self, "_state", {})

@state.setter
def state(self, value: MutableMapping[str, Any]) -> None:
    """Replace the stream's cursor state wholesale (set by read_records)."""
    self._state = value

def path(self, **kwargs) -> str:
    """Return the Notion API endpoint path used by this stream."""
    # All incremental Notion streams page through the /v1/search endpoint.
    endpoint = "search"
    return endpoint

Expand All @@ -202,8 +210,14 @@ def request_body_json(self, next_page_token: Mapping[str, Any] = None, **kwargs)
def read_records(self, sync_mode: SyncMode, stream_state: Mapping[str, Any] = None, **kwargs) -> Iterable[Mapping[str, Any]]:
if sync_mode == SyncMode.full_refresh:
stream_state = None

self.state = stream_state or {}


try:
yield from super().read_records(sync_mode, stream_state=stream_state, **kwargs)
for record in super().read_records(sync_mode, stream_state=stream_state, **kwargs):
self.state = self._get_updated_state(self.state, record)
yield record
except UserDefinedBackoffException as e:
message = self.check_invalid_start_cursor(e.response)
if message:
Expand All @@ -221,7 +235,7 @@ def parse_response(self, response: requests.Response, stream_state: Mapping[str,
if (not stream_state or record_lmd >= state_lmd) and record_lmd >= self.start_date:
yield record

def get_updated_state(
def _get_updated_state(
self,
current_stream_state: MutableMapping[str, Any],
latest_record: Mapping[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,18 @@ def test_get_updated_state(stream):
"latest_record": {"last_edited_time": "2021-10-20T00:00:00.000Z"},
}
expected_state = "2021-10-10T00:00:00.000Z"
state = stream.get_updated_state(**inputs)
state = stream._get_updated_state(**inputs)
assert state["last_edited_time"].value == expected_state

inputs = {"current_stream_state": state, "latest_record": {"last_edited_time": "2021-10-30T00:00:00.000Z"}}
state = stream.get_updated_state(**inputs)
state = stream._get_updated_state(**inputs)
assert state["last_edited_time"].value == expected_state

# after stream sync is finished, state should output the max cursor time
stream.is_finished = True
inputs = {"current_stream_state": state, "latest_record": {"last_edited_time": "2021-10-10T00:00:00.000Z"}}
expected_state = "2021-10-30T00:00:00.000Z"
state = stream.get_updated_state(**inputs)
state = stream._get_updated_state(**inputs)
assert state["last_edited_time"].value == expected_state


Expand Down

0 comments on commit 8bbe162

Please sign in to comment.