Skip to content

Commit

Permalink
Use code logger
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Jul 2, 2024
1 parent f4b781c commit 6e8bfcc
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/anemoi/training/diagnostics/mlflow/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# nor does it submit to any jurisdiction.

import json
import logging
import os
import time
from getpass import getpass
Expand All @@ -16,7 +15,7 @@
from anemoi.utils.config import save_config
from requests.exceptions import HTTPError

LOG = logging.getLogger(__name__)
from anemoi.training.utils.logger import get_code_logger


class TokenAuth:
Expand All @@ -28,8 +27,7 @@ def __init__(
refresh_expire_days=29,
enabled=True,
):
"""
Parameters
"""Parameters
----------
url : str
URL of the authentication server.
Expand All @@ -53,6 +51,10 @@ def __init__(
self.access_token = None
self.access_expires = 0

# the command line tool adds a default handler to the root logger on runtime,
# so we init our logger here (on runtime, not on import) to avoid duplicate handlers
self.log = get_code_logger(__name__)

def __call__(self):
self.authenticate()

Expand Down Expand Up @@ -87,14 +89,14 @@ def login(self, force_credentials=False, **kwargs):
if not self.enabled:
return

LOG.info(f"Logging in to {self.url}")
self.log.info(f"Logging in to {self.url}")
new_refresh_token = None

if not force_credentials and self.refresh_token and self.refresh_expires > time.time():
new_refresh_token = self._token_request(ignore_exc=True).get("refresh_token")

if not new_refresh_token:
LOG.info("Please sign in with your credentials.")
self.log.info("Please sign in with your credentials.")
username = input("Username: ")
password = getpass("Password: ")

Expand All @@ -106,7 +108,7 @@ def login(self, force_credentials=False, **kwargs):
self.refresh_token = new_refresh_token
self.save()

LOG.info("Successfully logged in to MLflow. Happy logging!")
self.log.info("Successfully logged in to MLflow. Happy logging!")

def authenticate(self, **kwargs):
"""Check the access token and refresh it if necessary.
Expand Down Expand Up @@ -139,13 +141,13 @@ def authenticate(self, **kwargs):

os.environ["MLFLOW_TRACKING_TOKEN"] = self.access_token

LOG.debug("Access token refreshed.")
self.log.debug("Access token refreshed.")

def save(self, **kwargs):
"""Save the latest refresh token to disk."""

if not self.refresh_token:
LOG.warning("No refresh token to save.")
self.log.warning("No refresh token to save.")
return

config = {
Expand Down Expand Up @@ -194,5 +196,5 @@ def _request(self, path, payload):

return response_json["response"]
except HTTPError as http_err:
LOG.error(f"HTTP error occurred: {http_err}")
self.log.error(f"HTTP error occurred: {http_err}")
raise

0 comments on commit 6e8bfcc

Please sign in to comment.