Merge remote-tracking branch 'origin/main' into baz/cdk/test-read-refactor
bazarnov committed Feb 13, 2025
2 parents 499b6a0 + 522caab commit 9da64f1
Showing 4 changed files with 109 additions and 4 deletions.
15 changes: 12 additions & 3 deletions airbyte_cdk/sources/declarative/extractors/record_selector.py
@@ -41,6 +41,7 @@ class RecordSelector(HttpSelector):
     _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="")
     record_filter: Optional[RecordFilter] = None
     transformations: List[RecordTransformation] = field(default_factory=lambda: [])
+    transform_before_filtering: bool = False
 
     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._parameters = parameters
@@ -104,9 +105,17 @@ def filter_and_transform(
         Until we decide to move this logic away from the selector, we made this method public so that users like AsyncJobRetriever could
         share the logic of doing transformations on a set of records.
         """
-        filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
-        transformed_data = self._transform(filtered_data, stream_state, stream_slice)
-        normalized_data = self._normalize_by_schema(transformed_data, schema=records_schema)
+        if self.transform_before_filtering:
+            transformed_data = self._transform(all_data, stream_state, stream_slice)
+            transformed_filtered_data = self._filter(
+                transformed_data, stream_state, stream_slice, next_page_token
+            )
+        else:
+            filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
+            transformed_filtered_data = self._transform(filtered_data, stream_state, stream_slice)
+        normalized_data = self._normalize_by_schema(
+            transformed_filtered_data, schema=records_schema
+        )
         for data in normalized_data:
             yield Record(data=data, stream_name=self.name, associated_slice=stream_slice)

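The new branch matters whenever a transformation produces or rewrites the very field the filter condition inspects. Below is a minimal standalone sketch of the two code paths (plain Python, independent of the CDK; the record shape and field names are invented to mirror the new unit test further down):

from typing import Any, Callable, Dict, List

Record = Dict[str, Any]


def select(
    records: List[Record],
    transform: Callable[[Record], Record],
    keep: Callable[[Record], bool],
    transform_before_filtering: bool,
) -> List[Record]:
    # Mirrors the branch added to filter_and_transform: either transform every
    # record and filter the transformed output, or filter the raw records and
    # transform only the survivors.
    if transform_before_filtering:
        return [r for r in map(transform, records) if keep(r)]
    return [transform(r) for r in records if keep(r)]


def bump(record: Record) -> Record:
    # Transformation: force 'myfield' to 999 on every record.
    return {**record, "myfield": 999}


def keep_999(record: Record) -> bool:
    # Filter: keep only records whose 'myfield' is already 999.
    return record["myfield"] == 999


records = [{"id": 1, "myfield": 0}, {"id": 2, "myfield": 999}]
print(select(records, bump, keep_999, transform_before_filtering=True))   # both ids survive
print(select(records, bump, keep_999, transform_before_filtering=False))  # only id=2 survives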
@@ -2415,6 +2415,8 @@ def create_record_selector(
             if model.record_filter
             else None
         )
+
+        transform_before_filtering = False
         if client_side_incremental_sync:
             record_filter = ClientSideIncrementalRecordFilterDecorator(
                 config=config,
@@ -2424,6 +2426,8 @@
                 else None,
                 **client_side_incremental_sync,
             )
+            transform_before_filtering = True
+
         schema_normalization = (
             TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])
             if isinstance(model.schema_normalization, SchemaNormalizationModel)
@@ -2438,6 +2442,7 @@
             transformations=transformations or [],
             schema_normalization=schema_normalization,
             parameters=model.parameters or {},
+            transform_before_filtering=transform_before_filtering,
         )
 
     @staticmethod
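Net effect of the factory change: transform_before_filtering defaults to False and is set to True only when a client-side incremental sync is configured. A plausible reading of this wiring (not stated in the diff itself): the ClientSideIncrementalRecordFilterDecorator compares each record's cursor value against stream state, so transformations that populate or normalize the cursor field need to run before that comparison, while ordinary streams keep the cheaper filter-first order.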
@@ -3,7 +3,7 @@
 #
 
 import json
-from unittest.mock import Mock, call
+from unittest.mock import MagicMock, Mock, call
 
 import pytest
 import requests
@@ -220,3 +220,91 @@ def create_schema():
             "field_float": {"type": "number"},
         },
     }
+
+
+@pytest.mark.parametrize("transform_before_filtering", [True, False])
+def test_transform_before_filtering(transform_before_filtering):
+    """
+    Verify that when transform_before_filtering=True, records are modified before
+    filtering. When False, the filter sees the original record data first.
+    """
+
+    # 1) Our response body with 'myfield' set differently:
+    #    the first record has myfield=0 (needs transformation to pass),
+    #    the second record has myfield=999 (already passes the filter).
+    body = {"data": [{"id": 1, "myfield": 0}, {"id": 2, "myfield": 999}]}
+
+    # 2) A response object
+    response = requests.Response()
+    response._content = json.dumps(body).encode("utf-8")
+
+    # 3) A simple extractor pulling records from 'data'
+    extractor = DpathExtractor(
+        field_path=["data"], decoder=JsonDecoder(parameters={}), config={}, parameters={}
+    )
+
+    # 4) A filter that keeps only records whose 'myfield' == 999,
+    #    i.e.: "{{ record['myfield'] == 999 }}"
+    record_filter = RecordFilter(
+        config={},
+        condition="{{ record['myfield'] == 999 }}",
+        parameters={},
+    )
+
+    # 5) A transformation that sets 'myfield' to 999.
+    #    We attach it to a mock so we can confirm how many times it was called.
+    transformation_mock = MagicMock(spec=RecordTransformation)
+
+    def transformation_side_effect(record, config, stream_state, stream_slice):
+        record["myfield"] = 999
+
+    transformation_mock.transform.side_effect = transformation_side_effect
+
+    # 6) Create a RecordSelector with transform_before_filtering set from our param
+    record_selector = RecordSelector(
+        extractor=extractor,
+        config={},
+        name="test_stream",
+        record_filter=record_filter,
+        transformations=[transformation_mock],
+        schema_normalization=TypeTransformer(TransformConfig.NoTransform),
+        transform_before_filtering=transform_before_filtering,
+        parameters={},
+    )
+
+    # 7) Collect records
+    stream_slice = StreamSlice(partition={}, cursor_slice={})
+    actual_records = list(
+        record_selector.select_records(
+            response=response,
+            records_schema={},  # not using schema in this test
+            stream_state={},
+            stream_slice=stream_slice,
+            next_page_token=None,
+        )
+    )
+
+    # 8) Assert how many records survive
+    if transform_before_filtering:
+        # Both records become myfield=999 BEFORE the filter => both pass
+        assert len(actual_records) == 2
+        # The transformation should be called 2x (once per record)
+        assert transformation_mock.transform.call_count == 2
+    else:
+        # The first record is myfield=0 when the filter sees it => filter excludes it;
+        # the second record is myfield=999 => filter includes it.
+        assert len(actual_records) == 1
+        # The transformation occurs only on that single surviving record
+        # (the filter is done first, so the first record is already dropped).
+        assert transformation_mock.transform.call_count == 1
+
+    # 9) Check final record data:
+    #    if transform_before_filtering=True => records [1, 2], both with myfield=999;
+    #    if transform_before_filtering=False => record [2] with myfield=999.
+    final_record_data = [r.data for r in actual_records]
+    if transform_before_filtering:
+        assert all(record["myfield"] == 999 for record in final_record_data)
+        assert sorted([r["id"] for r in final_record_data]) == [1, 2]
+    else:
+        assert final_record_data[0]["id"] == 2
+        assert final_record_data[0]["myfield"] == 999
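A side note on the call_count assertions: they pin down the cost difference between the two modes, since with filter-first ordering a record dropped by the filter is never transformed at all. That is presumably why False stays the default for streams that do not need client-side incremental filtering.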
@@ -1194,6 +1194,8 @@ def test_client_side_incremental():
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
     )
 
+    assert stream.retriever.record_selector.transform_before_filtering == True
+
 
 def test_client_side_incremental_with_partition_router():
     content = """
@@ -1274,6 +1276,7 @@
     assert isinstance(
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
     )
+    assert stream.retriever.record_selector.transform_before_filtering == True
     assert isinstance(
         stream.retriever.record_selector.record_filter._cursor,
         PerPartitionWithGlobalCursor,
