Skip to content

Commit

Permalink
Manage per partition twilio state
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrawc committed Dec 11, 2024
1 parent 71592bf commit e542900
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(
self._slice_step = slice_step and pendulum.duration(days=slice_step)
self._start_date = start_date if start_date is not None else "1970-01-01T00:00:00Z"
self._lookback_window = lookback_window
self._cursor_value = None
self._state = {"states": []}

@property
def slice_step(self):
Expand All @@ -156,29 +156,26 @@ def upper_boundary_filter_field(self) -> str:

@property
def state(self) -> Mapping[str, Any]:
if self._cursor_value:
return {
self.cursor_field: self._cursor_value,
}

return {}
return self._state

@state.setter
def state(self, value: MutableMapping[str, Any]):
if self._lookback_window and value.get(self.cursor_field):
new_start_date = (
pendulum.parse(value[self.cursor_field]) - pendulum.duration(minutes=self._lookback_window)
).to_iso8601_string()
if new_start_date > self._start_date:
value[self.cursor_field] = new_start_date
self._cursor_value = value.get(self.cursor_field)

def generate_date_ranges(self) -> Iterable[Optional[MutableMapping[str, Any]]]:
if self._lookback_window:
lookback_duration = pendulum.duration(minutes=self._lookback_window)
for state in value.get("states", []):
cursor = state.get("cursor", {})
if self.cursor_field in cursor:
new_start_date = (pendulum.parse(cursor[self.cursor_field]) - lookback_duration).to_iso8601_string()
if new_start_date > self._start_date:
cursor[self.cursor_field] = new_start_date
self._state = value

def generate_date_ranges(self, partition: MutableMapping[str, Any]) -> Iterable[Optional[MutableMapping[str, Any]]]:
def align_to_dt_format(dt: DateTime) -> DateTime:
return pendulum.parse(dt.format(self.time_filter_template))

end_datetime = pendulum.now("utc")
start_datetime = min(end_datetime, pendulum.parse(self.state.get(self.cursor_field, self._start_date)))
start_datetime = min(end_datetime, pendulum.parse(self._get_partition_state(partition).get(self.cursor_field, self._start_date)))
current_start = start_datetime
current_end = start_datetime
# Aligning to a datetime format is done to avoid the following scenario:
Expand All @@ -196,10 +193,10 @@ def align_to_dt_format(dt: DateTime) -> DateTime:
current_start = current_end + self.slice_granularity

def stream_slices(
self, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
self, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: StreamSlice = None
) -> Iterable[Optional[Mapping[str, Any]]]:
for super_slice in super().stream_slices(sync_mode=sync_mode, cursor_field=cursor_field, stream_state=stream_state):
for dt_range in self.generate_date_ranges():
for dt_range in self.generate_date_ranges(super_slice.partition if super_slice else {}):
yield StreamSlice(partition=super_slice.partition if super_slice else {}, cursor_slice=dt_range)

def request_params(
Expand All @@ -221,15 +218,32 @@ def read_records(
self,
sync_mode: SyncMode,
cursor_field: List[str] = None,
stream_slice: Mapping[str, Any] = None,
stream_slice: StreamSlice = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
if stream_slice is None:
stream_slice = StreamSlice(partition={}, cursor_slice={})
for record in super().read_records(sync_mode, cursor_field, stream_slice, stream_state):
record[self.cursor_field] = pendulum.parse(record[self.cursor_field], strict=False).to_iso8601_string()
if record[self.cursor_field] >= self.state.get(self.cursor_field, self._start_date):
self._cursor_value = record[self.cursor_field]
if record[self.cursor_field] >= self._get_partition_state(stream_slice.partition).get(self.cursor_field, self._start_date):
self._state = self._update_partition_state(stream_slice.partition, {self.cursor_field: record[self.cursor_field]})
yield record

def _update_partition_state(self, partition: Mapping[str, Any], cursor: Mapping[str, Any]) -> Mapping[str, Any]:
states = copy.deepcopy(self._state.get("states", []))
for state in states:
if state.get("partition") == partition:
state.update({"cursor": cursor})
return self._state
states.append({"partition": partition, "cursor": cursor})
return {"states": states}

def _get_partition_state(self, partition: Mapping[str, Any]) -> Mapping[str, Any]:
for state in self._state.get("states", []):
if state.get("partition") == partition:
return state.get("cursor", {})
return {}


class TwilioNestedStream(TwilioStream):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_request_params(self, stream_cls, stream_slice, next_page_token, expecte
def test_read_records(self, stream_cls, record, expected):
stream = stream_cls(**self.CONFIG)
with patch.object(HttpStream, "read_records", return_value=record):
result = stream.read_records(sync_mode=None)
result = stream.read_records(sync_mode=None, stream_slice=StreamSlice(partition={}, cursor_slice={}))
assert list(result) == expected

@pytest.mark.parametrize(
Expand Down Expand Up @@ -231,17 +231,42 @@ def test_stream_slices(self, mocker, stream_cls, parent_cls_records, extra_slice
(
(
Messages,
{"date_sent": "2022-11-13 23:39:00"},
{
"states": [
{
"partition": {"key": "value"},
"cursor": {"date_sent": "2022-11-13 23:39:00"},
}
]
},
[
{"DateSent>": "2022-11-13 23:39:00Z", "DateSent<": "2022-11-14 23:39:00Z"},
{"DateSent>": "2022-11-14 23:39:00Z", "DateSent<": "2022-11-15 23:39:00Z"},
{"DateSent>": "2022-11-15 23:39:00Z", "DateSent<": "2022-11-16 12:03:11Z"},
],
),
(UsageRecords, {"start_date": "2021-11-16 00:00:00"}, [{"StartDate": "2021-11-16", "EndDate": "2022-11-16"}]),
(
UsageRecords,
{
"states": [
{
"partition": {"key": "value"},
"cursor": {"start_date": "2021-11-16 00:00:00"},
}
]
},
[{"StartDate": "2021-11-16", "EndDate": "2022-11-16"}],
),
(
Recordings,
{"date_created": "2021-11-16 00:00:00"},
{
"states": [
{
"partition": {"key": "value"},
"cursor": {"date_created": "2021-11-16 00:00:00"},
}
]
},
[
{"DateCreated>": "2021-11-16 00:00:00Z", "DateCreated<": "2022-11-16 00:00:00Z"},
{"DateCreated>": "2022-11-16 00:00:00Z", "DateCreated<": "2022-11-16 12:03:11Z"},
Expand All @@ -252,7 +277,7 @@ def test_stream_slices(self, mocker, stream_cls, parent_cls_records, extra_slice
def test_generate_dt_ranges(self, stream_cls, state, expected_dt_ranges):
stream = stream_cls(authenticator=TEST_CONFIG.get("authenticator"), start_date="2000-01-01 00:00:00")
stream.state = state
dt_ranges = list(stream.generate_date_ranges())
dt_ranges = list(stream.generate_date_ranges({"key": "value"}))
assert dt_ranges == expected_dt_ranges


Expand Down

0 comments on commit e542900

Please sign in to comment.