diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 040ac4689..6ed24295a 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -2395,21 +2395,12 @@ def create_http_requester( api_budget = self._api_budget - # Removes QueryProperties components from the interpolated mappings because it has been designed - # to be used by the SimpleRetriever and will be resolved from the provider from the slice directly - # instead of through jinja interpolation - request_parameters: Optional[Union[str, Mapping[str, str]]] - if isinstance(model.request_parameters, Mapping): - request_parameters = self._remove_query_properties(model.request_parameters) - else: - request_parameters = model.request_parameters - request_options_provider = InterpolatedRequestOptionsProvider( request_body=model.request_body, request_body_data=model.request_body_data, request_body_json=model.request_body_json, request_headers=model.request_headers, - request_parameters=request_parameters, + request_parameters=model.request_parameters, # type: ignore # QueryProperties have been removed in `create_simple_retriever` query_properties_key=query_properties_key, config=config, parameters=model.parameters or {}, @@ -3199,7 +3190,8 @@ def _get_url(req: Requester) -> str: query_properties: Optional[QueryProperties] = None query_properties_key: Optional[str] = None - if self._query_properties_in_request_parameters(model.requester): + self._ensure_query_properties_to_model(model.requester) + if self._has_query_properties_in_request_parameters(model.requester): # It is better to be explicit about an error if PropertiesFromEndpoint is defined in multiple # places instead of default to request_parameters which isn't clearly documented if ( @@ -3211,7 +3203,7 @@ def _get_url(req: Requester) -> str: ) 
def _ensure_query_properties_to_model(
    self, requester: Union[HttpRequesterModel, CustomRequesterModel]
) -> None:
    """
    Turn raw `QueryProperties` dictionaries in the requester's request parameters into models.

    The request parameters of a requester model (reported for CustomRequesterModel) may still hold plain
    dictionaries, so checks that look for `QueryPropertiesModel` instances would not match them. Every
    dictionary value whose `type` is "QueryProperties" is parsed into a `QueryPropertiesModel` and stored
    back under the same key.

    The mapping is updated in place and nothing is returned. Requesters without a `request_parameters`
    attribute, or whose request parameters are empty or not a dictionary, are left untouched.
    """
    raw_parameters = getattr(requester, "request_parameters", None)
    if not raw_parameters or not isinstance(raw_parameters, Dict):
        return

    # Iterate over a snapshot so the values can be reassigned while looping.
    for parameter_name, parameter_value in list(raw_parameters.items()):
        is_query_properties_definition = (
            isinstance(parameter_value, Dict)
            and parameter_value.get("type") == "QueryProperties"
        )
        if is_query_properties_definition:
            raw_parameters[parameter_name] = QueryPropertiesModel.parse_obj(parameter_value)
def test_stream_with_custom_requester_and_query_properties(requests_mock):
    """
    A stream whose CustomRequester declares a QueryProperties request parameter can be created and read.
    """
    content = """
a_stream:
  type: DeclarativeStream
  primary_key: id
  schema_loader:
    type: InlineSchemaLoader
    schema:
      $schema: "http://json-schema.org/draft-07/schema"
      type: object
      properties:
        id:
          type: string
  retriever:
    type: SimpleRetriever
    name: "{{ parameters['name'] }}"
    decoder:
      type: JsonDecoder
    requester:
      type: CustomRequester
      class_name: unit_tests.sources.declarative.parsers.testing_components.TestingRequester
      name: "{{ parameters['name'] }}"
      url_base: "https://api.sendgrid.com/v3/"
      path: "path"
      http_method: "GET"
      request_parameters:
        not_query: 1
        query:
          type: QueryProperties
          property_list:
            - id
            - field
          always_include_properties:
            - id
          property_chunking:
            type: PropertyChunking
            property_limit_type: property_count
            property_limit: 18
    record_selector:
      type: RecordSelector
      extractor:
        type: DpathExtractor
        field_path: ["records"]
  $parameters:
    name: a_stream
"""

    # Build the stream definition the same way a declarative source would.
    parsed_manifest = YamlDeclarativeSource._parse(content)
    resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
    stream_manifest = transformer.propagate_types_and_parameters(
        "", resolved_manifest["a_stream"], {}
    )
    stream = factory.create_component(
        model_type=DeclarativeStreamModel,
        component_definition=stream_manifest,
        config=input_config,
    )

    # The endpoint answers with a single record under the "records" key.
    response_body = json.dumps({"records": [{"id": "1"}]})
    requests_mock.get(
        "https://api.sendgrid.com/v3/path",
        text=response_body,
        status_code=200,
    )

    first_partition = next(stream.generate_partitions())
    records = list(first_partition.read())

    assert len(records) == 1
@dataclass
class TestingRequester(HttpRequester):
    """
    HttpRequester subclass for tests that builds its own request options provider.

    The provider is created in `__post_init__` from the `request_parameters` field, the requester's
    `config` and the `parameters` handed to `__post_init__`.
    """

    # Request parameters forwarded to the InterpolatedRequestOptionsProvider built in __post_init__.
    request_parameters: Optional[RequestInput] = None

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """
        Create the request options provider from `request_parameters`, then run the parent initialization.

        :param parameters: component parameters; an empty mapping is passed to the provider when falsy
        """
        options_provider = InterpolatedRequestOptionsProvider(
            request_parameters=self.request_parameters,
            config=self.config,
            parameters=parameters if parameters else {},
        )
        # Assigned before the parent initialization runs, as in the original ordering.
        self.request_options_provider = options_provider
        super().__post_init__(parameters)