diff --git a/pulp-glue/pulp_glue/common/authentication.py b/pulp-glue/pulp_glue/common/authentication.py index 8fd3a9396..2548fc35d 100644 --- a/pulp-glue/pulp_glue/common/authentication.py +++ b/pulp-glue/pulp_glue/common/authentication.py @@ -12,25 +12,29 @@ class OAuth2ClientCredentialsAuth(requests.auth.AuthBase): def __init__( self, - client_id: str = "", - client_secret: str = "", - token_url: str = "", - scopes: t.List[str] = [], + client_id: str, + client_secret: str, + token_url: str, + scopes: t.List[str], ): self.client_id = client_id self.client_secret = client_secret self.token_url = token_url self.scopes = scopes - self.token: t.Dict[t.Any, t.Any] = {} + + self.access_token: t.Optional[str] = None + self.expire_at: t.Optional[datetime] = None def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: - if self.is_token_expired(): + if self.expire_at is None or self.expire_at < datetime.now(): self.retrieve_token() - access_token = self.token.get("access_token") - request.headers["Authorization"] = f"Bearer {access_token}" + assert self.access_token is not None + + request.headers["Authorization"] = f"Bearer {self.access_token}" - request.register_hook("response", self.handle401) # type: ignore + # Call to untyped function "register_hook" in typed context + request.register_hook("response", self.handle401) # type: ignore[no-untyped-call] return request @@ -42,27 +46,35 @@ def handle401( if response.status_code != 401: return response - if self.is_token_expired(): - self.retrieve_token() + # If we get this far, probably the token is not valid anymore. - response.content - prep = response.request.copy() + # Try to reach for a new token once. + self.retrieve_token() + + assert self.access_token is not None - access_token = self.token.get("access_token") - prep.headers["Authorization"] = f"Bearer {access_token}" + # Consume content and release the original connection + # to allow our new request to reuse the same one. + response.content + response.close() + prepared_new_request = response.request.copy() - prep.deregister_hook("response", self.handle401) + prepared_new_request.headers["Authorization"] = f"Bearer {self.access_token}" - _response: requests.Response = response.connection.send(prep, **kwargs) # type: ignore - _response.history.append(response) - _response.request = prep + # Avoid to enter into an infinity loop. + # Call to untyped function "deregister_hook" in typed context + prepared_new_request.deregister_hook( # type: ignore[no-untyped-call] + "response", self.handle401 + ) - return _response + # "Response" has no attribute "connection" + new_response: requests.Response = response.connection.send( # type: ignore[attr-defined] + prepared_new_request, **kwargs + ) + new_response.history.append(response) + new_response.request = prepared_new_request - def is_token_expired(self) -> bool: - if self.token: - return self.token["expires_at"] < datetime.now() - return True + return new_response def retrieve_token(self) -> None: data = { @@ -76,5 +88,6 @@ def retrieve_token(self) -> None: response.raise_for_status() - self.token = response.json() - self.token["expires_at"] = datetime.now() + timedelta(seconds=self.token["expires_in"]) + token = response.json() + self.expire_at = datetime.now() + timedelta(seconds=token["expires_in"]) + self.access_token = token["access_token"]