Skip to content

Commit

Permalink
adds session to OAuth constructor so that we can mock the session to …
Browse files Browse the repository at this point in the history
…run on CI.
  • Loading branch information
willi-mueller committed Jun 20, 2024
1 parent 352b5d7 commit edfff5b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(
client_secret: TSecretStrValue,
access_token_request_data: Dict[str, Any] = None,
default_token_expiration: int = 3600,
session: Annotated[BaseSession, NotResolved()] = None,
) -> None:
super().__init__()
self.access_token_url = access_token_url
Expand All @@ -175,6 +176,8 @@ def __init__(
self.access_token_request_data = access_token_request_data
self.default_token_expiration = default_token_expiration
self.token_expiry: pendulum.DateTime = pendulum.now()
if session is None:
self.session = requests.client.session

def __call__(self, request: PreparedRequest) -> PreparedRequest:
if self.access_token is None or self.is_token_expired():
Expand Down Expand Up @@ -233,7 +236,7 @@ def __post_init__(self) -> None:
self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None
# use default system session is not specified
# use default system session unless specified otherwise
if self.session is None:
self.session = requests.client.session

Expand Down
5 changes: 5 additions & 0 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def rest_client_oauth() -> RESTClient:
access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"),
client_id=cast(TSecretStrValue, "test-client-id"),
client_secret=cast(TSecretStrValue, "test-client-secret"),
session=Client().session,
)
return build_rest_client(auth=auth)

Expand All @@ -65,6 +66,7 @@ def rest_client_immediate_oauth_expiry(auth=None) -> RESTClient:
access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token-expires-now"),
client_id=cast(TSecretStrValue, "test-client-id"),
client_secret=cast(TSecretStrValue, "test-client-secret"),
session=Client().session,
)
return build_rest_client(auth=credentials_expiring_now)

Expand Down Expand Up @@ -207,6 +209,7 @@ def test_oauth2_client_credentials_flow_wrong_client_id(self, rest_client: RESTC
access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"),
client_id=cast(TSecretStrValue, "invalid-client-id"),
client_secret=cast(TSecretStrValue, "test-client-secret"),
session=Client().session,
)

with pytest.raises(HTTPError) as e:
Expand All @@ -219,6 +222,7 @@ def test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: R
access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"),
client_id=cast(TSecretStrValue, "test-client-id"),
client_secret=cast(TSecretStrValue, "invalid-client-secret"),
session=Client().session,
)

with pytest.raises(HTTPError) as e:
Expand Down Expand Up @@ -274,6 +278,7 @@ def build_access_token_request(self) -> Dict[str, Any]:
access_token_request_data={
"account_id": cast(TSecretStrValue, "test-account-id"),
},
session=Client().session,
)

assert auth.build_access_token_request() == {
Expand Down

0 comments on commit edfff5b

Please sign in to comment.