From 5b8421e7227059583a02bee5086f6e8b55c5e0f5 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 9 Dec 2024 17:49:31 +0100 Subject: [PATCH] Fix validation error in for custom auth classes --- dlt/common/typing.py | 15 +++++++++++++++ dlt/sources/rest_api/config_setup.py | 6 +++++- .../configurations/test_custom_auth_config.py | 17 ++++++++++++++++- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index a3364d1b07..c8080b548d 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -484,3 +484,18 @@ def decorator( return func return decorator + + +def add_value_to_literal(literal: Any, value: Any) -> None: + """Extends a Literal at runtime with a new value. + + Args: + literal (Type[Any]): Literal to extend + value (Any): Value to add + + """ + type_args = get_args(literal) + + if value not in type_args: + type_args += (value,) + literal.__args__ = type_args diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d03a4fd59b..3ce81f3aa7 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -20,6 +20,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.schema.utils import merge_columns from dlt.common.utils import update_dict_nested, exclude_keys +from dlt.common.typing import add_value_to_literal from dlt.common import jsonpath from dlt.extract.incremental import Incremental @@ -64,6 +65,7 @@ ResponseActionDict, Endpoint, EndpointResource, + AuthType, ) @@ -153,6 +155,8 @@ def register_auth( ) AUTH_MAP[auth_name] = auth_class + add_value_to_literal(AuthType, auth_name) + def get_auth_class(auth_type: str) -> Type[AuthConfigBase]: try: @@ -285,7 +289,7 @@ def build_resource_dependency_graph( resolved_param_map[resource_name] = None break assert isinstance(endpoint_resource["endpoint"], dict) - # connect transformers to resources via resolved params + # find resolved parameters to connect dependent resources resolved_params = _find_resolved_params(endpoint_resource["endpoint"]) # set of resources in resolved params diff --git a/tests/sources/rest_api/configurations/test_custom_auth_config.py b/tests/sources/rest_api/configurations/test_custom_auth_config.py index 1a5a2e58a3..132bd67e88 100644 --- a/tests/sources/rest_api/configurations/test_custom_auth_config.py +++ b/tests/sources/rest_api/configurations/test_custom_auth_config.py @@ -5,7 +5,7 @@ from dlt.sources import rest_api from dlt.sources.helpers.rest_client.auth import APIKeyAuth, OAuth2ClientCredentials -from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig +from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig, RESTAPIConfig class CustomOAuth2(OAuth2ClientCredentials): @@ -77,3 +77,18 @@ class NotAuthConfigBase: "not_an_auth_config_base", NotAuthConfigBase # type: ignore ) assert e.match("Invalid auth: NotAuthConfigBase.") + + def test_validate_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None: + rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2) + + valid_config: RESTAPIConfig = { + "client": { + "base_url": "https://example.com", + "auth": custom_auth_config, + }, + "resources": ["test"], + } + + rest_api.rest_api_source(valid_config) + + del rest_api.config_setup.AUTH_MAP["custom_oauth_2"]