Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLflow token authentication #2

Merged
merged 41 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
3531281
Update dependencies
gmertes Jun 25, 2024
1c95835
Add mlflow subpackage
gmertes Jun 25, 2024
624fabb
Add auth module
gmertes Jun 25, 2024
fbe8d0a
Use unix time
gmertes Jun 26, 2024
a05a374
Store new refresh token in memory
gmertes Jun 26, 2024
7dec83c
Use anemoi.utils.config
gmertes Jun 26, 2024
683069b
Add authenticate fn
gmertes Jun 26, 2024
7e5bb01
Load config on init
gmertes Jun 26, 2024
70bf8c3
Refactor login logic
gmertes Jun 26, 2024
224897d
Add force credentials arg
gmertes Jun 27, 2024
460d598
Check refresh token expiry on auth
gmertes Jun 27, 2024
de93951
Handle error status
gmertes Jun 27, 2024
aecf26a
Retry login with credentials if token fails
gmertes Jun 27, 2024
ff3206b
Add mlflow command
gmertes Jun 27, 2024
2140a56
Assert url
gmertes Jun 27, 2024
868f269
Header
gmertes Jun 28, 2024
c5acb4f
Add docstrings
gmertes Jun 28, 2024
c4d9a4e
Fix logger import
gmertes Jun 28, 2024
d39541f
Make refresh token a property, update expiry in setter
gmertes Jun 28, 2024
f7fb94a
Refactor token request, update refresh token on auth
gmertes Jun 28, 2024
f4b781c
Update mlflow logger to latest aifs-mono version
gmertes Jul 2, 2024
6e8bfcc
Use code logger
gmertes Jul 2, 2024
d7b2abc
Return a global auth instance with get_auth
gmertes Jul 2, 2024
4a7509c
Add enabled decorator
gmertes Jul 2, 2024
131b138
Log refresh token expiry date
gmertes Jul 2, 2024
c4d8b2e
Simplify expiry calculation
gmertes Jul 2, 2024
6edf62d
Remove assert url
gmertes Jul 2, 2024
fb2c2e8
Remove global auth facility
gmertes Jul 3, 2024
b3a1ea8
Integrate TokenAuth into AIFSMLflowLogger
gmertes Jul 3, 2024
04f90d5
Only initialise on rank 0
gmertes Jul 5, 2024
47d7e69
Log time it takes to refresh token
gmertes Jul 5, 2024
9e036a0
Add tests
gmertes Jul 5, 2024
b0beb3b
Add authentication config entry
gmertes Jul 8, 2024
0fb6067
Comment
gmertes Jul 10, 2024
4ff1339
Add health check
gmertes Jul 12, 2024
b3ae738
Make refresh token expiry a constant
gmertes Jul 12, 2024
290e48f
Make --url required, in absence of a config
gmertes Jul 16, 2024
edcd1e1
Rename AIFSMLflowLogger -> AnemoiMLflowLogger
gmertes Jul 17, 2024
4756457
API update
gmertes Jul 17, 2024
60a2928
add ability to continue run in mlflow logs and not create child run
gmertes Jul 18, 2024
808bd3d
Add mlflow sync command placeholder, update logging
gmertes Jul 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ dynamic = [
"version",
]
dependencies = [
"anemoi-datasets[data]>=0.1",
"anemoi-models @ git+https://github.com/ecmwf/anemoi-models.git",
"anemoi-utils[provenance]>=0.1.3",
"anemoi-datasets>=0.1",
"anemoi-models",
gmertes marked this conversation as resolved.
Show resolved Hide resolved
"anemoi-utils[provenance]>=0.3.7",
"einops>=0.6.1",
"hydra-core>=1.3",
"matplotlib>=3.7.1",
Expand All @@ -76,6 +76,8 @@ optional-dependencies.all = [
optional-dependencies.dev = [
"nbsphinx",
"pandoc",
"pytest",
"pytest-mock",
"sphinx",
"sphinx-argparse",
"sphinx-rtd-theme",
Expand All @@ -89,6 +91,11 @@ optional-dependencies.docs = [
"sphinx-rtd-theme",
]

optional-dependencies.tests = [
"pytest",
"pytest-mock",
]
gmertes marked this conversation as resolved.
Show resolved Hide resolved

urls.Documentation = "https://anemoi-training.readthedocs.io/"
urls.Homepage = "https://github.com/ecmwf/anemoi-training/"
urls.Issues = "https://github.com/ecmwf/anemoi-training/issues"
Expand Down
40 changes: 40 additions & 0 deletions src/anemoi/training/commands/mlflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from anemoi.training.diagnostics.mlflow.auth import TokenAuth

from . import Command


class MlFlow(Command):
"""Commands to interact with MLflow."""

def add_arguments(self, command_parser):
subparsers = command_parser.add_subparsers(dest="subcommand", required=True)

login = subparsers.add_parser(
"login",
help="Log in and acquire a token from keycloak.",
)
login.add_argument(
"--url",
help="The URL of the authentication server",
)
login.add_argument(
"--force-credentials",
"-f",
action="store_true",
help="Force a credential login even if a token is available.",
)

def run(self, args):
if args.subcommand == "login":
TokenAuth(url=args.url).login(force_credentials=args.force_credentials)


command = MlFlow
gmertes marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion src/anemoi/training/diagnostics/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def tracker_metadata(self, trainer):
elif self.config.diagnostics.log.mlflow.enabled:
self._tracker_name = "mlflow"

from anemoi.training.diagnostics.mlflow_logger import AIFSMLflowLogger
from anemoi.training.diagnostics.mlflow.logger import AIFSMLflowLogger
gmertes marked this conversation as resolved.
Show resolved Hide resolved

mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AIFSMLflowLogger))
run_id = mlflow_logger.run_id
Expand Down
Empty file.
211 changes: 211 additions & 0 deletions src/anemoi/training/diagnostics/mlflow/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import json
import os
import time
from datetime import datetime
from functools import wraps
from getpass import getpass

import requests
from anemoi.utils.config import load_config
from anemoi.utils.config import save_config
from anemoi.utils.timer import Timer
from requests.exceptions import HTTPError

from anemoi.training.utils.logger import get_code_logger


class TokenAuth:
"""Manage authentication with a keycloak token server."""

def __init__(
self,
url,
refresh_expire_days=29,
gmertes marked this conversation as resolved.
Show resolved Hide resolved
enabled=True,
):
"""Parameters
----------
url : str
URL of the authentication server.
refresh_expire_days : int, optional
Number of days before the refresh token expires, by default 29
enabled : bool, optional
Set this to False to turn off authentication, by default True
"""

self.url = url
gmertes marked this conversation as resolved.
Show resolved Hide resolved
self.refresh_expire_days = refresh_expire_days
self._enabled = enabled

self.config_file = "mlflow-token.json"
gmertes marked this conversation as resolved.
Show resolved Hide resolved
config = load_config(self.config_file)

self._refresh_token = config.get("refresh_token")
self.refresh_expires = config.get("refresh_expires", 0)
self.access_token = None
self.access_expires = 0

# the command line tool adds a default handler to the root logger on runtime,
# so we init our logger here (on runtime, not on import) to avoid duplicate handlers
self.log = get_code_logger(__name__)
gmertes marked this conversation as resolved.
Show resolved Hide resolved

def __call__(self):
self.authenticate()

@property
def refresh_token(self):
return self._refresh_token

@refresh_token.setter
def refresh_token(self, value):
self._refresh_token = value
self.refresh_expires = time.time() + (self.refresh_expire_days * 86400)
gmertes marked this conversation as resolved.
Show resolved Hide resolved

def enabled(fn):
"""Decorator to call or ignore a function based on the `enabled` flag."""

@wraps(fn)
def _wrapper(self, *args, **kwargs):
if self._enabled:
return fn(self, *args, **kwargs)
return

return _wrapper

@enabled
def login(self, force_credentials=False, **kwargs):
"""Acquire a new refresh token and save it to disk.

If an existing valid refresh token is already on disk it will be used.
If not, or the token has expired, the user will be prompted for credentials.

This function should be called once, interactively, right before starting a training run.

Parameters
----------
force_credentials : bool, optional
Force a username/password prompt even if a refreh token is available, by default False.

Raises
------
RuntimeError
A new refresh token could not be acquired.
"""

self.log.info(f"Logging in to {self.url}")
new_refresh_token = None

if not force_credentials and self.refresh_token and self.refresh_expires > time.time():
new_refresh_token = self._token_request(ignore_exc=True).get("refresh_token")

if not new_refresh_token:
self.log.info("Please sign in with your credentials.")
username = input("Username: ")
password = getpass("Password: ")

new_refresh_token = self._token_request(username=username, password=password).get("refresh_token")

if not new_refresh_token:
raise RuntimeError("Failed to log in. Please try again.")

self.refresh_token = new_refresh_token
self.save()

self.log.info("Successfully logged in to MLflow. Happy logging!")

@enabled
def authenticate(self, **kwargs):
"""Check the access token and refresh it if necessary.

The access token is stored in memory and in the environment variable `MLFLOW_TRACKING_TOKEN`.
If the access token is still valid, this function does nothing.

This function should be called before every MLflow API request.

Raises
------
RuntimeError
No refresh token is available or the token request failed.
"""

if self.access_expires > time.time():
return

if not self.refresh_token or self.refresh_expires < time.time():
raise RuntimeError("You are not logged in to MLflow. Please log in first.")

with Timer("Access token refreshed", self.log):
response = self._token_request()

self.access_token = response.get("access_token")
self.access_expires = time.time() + (response.get("expires_in") * 0.7) # bit of buffer
self.refresh_token = response.get("refresh_token")

os.environ["MLFLOW_TRACKING_TOKEN"] = self.access_token

@enabled
def save(self, **kwargs):
"""Save the latest refresh token to disk."""

if not self.refresh_token:
self.log.warning("No refresh token to save.")
return

config = {
"refresh_token": self.refresh_token,
"refresh_expires": self.refresh_expires,
}
save_config(self.config_file, config)

expire_date = datetime.fromtimestamp(self.refresh_expires)
self.log.info("Your MLflow token is valid until %s UTC", expire_date.strftime("%Y-%m-%d %H:%M:%S"))

def _token_request(self, username=None, password=None, ignore_exc=False):
if username is not None and password is not None:
path = "newtoken"
payload = {"username": username, "password": password}
else:
path = "refreshtoken"
payload = {"refresh_token": self.refresh_token}

try:
response = self._request(path, payload)
except Exception as err:
if ignore_exc:
return {}
raise err

return response

def _request(self, path, payload):

headers = {
"Content-Type": "application/x-www-form-urlencoded",
}

try:
response = requests.post(f"{self.url}/{path}", headers=headers, json=payload)
response.raise_for_status()
response_json = response.json()

if response_json.get("status", "") != "OK":
# TODO: there's a bug in the API that returns the error response as a string instead of a json object.
# Remove this when the API is fixed.
if isinstance(response_json["response"], str):
error = json.loads(response_json["response"])
else:
error = response_json["response"]
error_description = error.get("error_description", "Error acquiring token.")
raise RuntimeError(error_description)

return response_json["response"]
except HTTPError as http_err:
self.log.error(f"HTTP error occurred: {http_err}")
raise
Loading
Loading