diff --git a/src/anemoi/training/diagnostics/mlflow/auth.py b/src/anemoi/training/diagnostics/mlflow/auth.py index 100a69d3..9d1730ff 100644 --- a/src/anemoi/training/diagnostics/mlflow/auth.py +++ b/src/anemoi/training/diagnostics/mlflow/auth.py @@ -6,7 +6,6 @@ # nor does it submit to any jurisdiction. import json -import logging import os import time from getpass import getpass @@ -16,7 +15,7 @@ from anemoi.utils.config import save_config from requests.exceptions import HTTPError -LOG = logging.getLogger(__name__) +from anemoi.training.utils.logger import get_code_logger class TokenAuth: @@ -28,8 +27,7 @@ def __init__( refresh_expire_days=29, enabled=True, ): - """ - Parameters + """Parameters ---------- url : str URL of the authentication server. @@ -53,6 +51,10 @@ def __init__( self.access_token = None self.access_expires = 0 + # the command line tool adds a default handler to the root logger on runtime, + # so we init our logger here (on runtime, not on import) to avoid duplicate handlers + self.log = get_code_logger(__name__) + def __call__(self): self.authenticate() @@ -87,14 +89,14 @@ def login(self, force_credentials=False, **kwargs): if not self.enabled: return - LOG.info(f"Logging in to {self.url}") + self.log.info(f"Logging in to {self.url}") new_refresh_token = None if not force_credentials and self.refresh_token and self.refresh_expires > time.time(): new_refresh_token = self._token_request(ignore_exc=True).get("refresh_token") if not new_refresh_token: - LOG.info("Please sign in with your credentials.") + self.log.info("Please sign in with your credentials.") username = input("Username: ") password = getpass("Password: ") @@ -106,7 +108,7 @@ def login(self, force_credentials=False, **kwargs): self.refresh_token = new_refresh_token self.save() - LOG.info("Successfully logged in to MLflow. Happy logging!") + self.log.info("Successfully logged in to MLflow. Happy logging!") def authenticate(self, **kwargs): """Check the access token and refresh it if necessary. @@ -139,13 +141,13 @@ def authenticate(self, **kwargs): os.environ["MLFLOW_TRACKING_TOKEN"] = self.access_token - LOG.debug("Access token refreshed.") + self.log.debug("Access token refreshed.") def save(self, **kwargs): """Save the latest refresh token to disk.""" if not self.refresh_token: - LOG.warning("No refresh token to save.") + self.log.warning("No refresh token to save.") return config = { @@ -194,5 +196,5 @@ def _request(self, path, payload): return response_json["response"] except HTTPError as http_err: - LOG.error(f"HTTP error occurred: {http_err}") + self.log.error(f"HTTP error occurred: {http_err}") raise