This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
AnemoiMlflowClient
with auth support (#86)
* feat: anemoi mlflow client with authentication * fix: recursion on anemoi_auth * chore: add tests * chore: changelog
- Loading branch information
Showing
3 changed files
with
95 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from mlflow import MlflowClient | ||
|
||
from anemoi.training.diagnostics.mlflow.auth import TokenAuth | ||
from anemoi.training.diagnostics.mlflow.utils import health_check | ||
|
||
|
||
class AnemoiMlflowClient(MlflowClient): | ||
"""Anemoi extension of the MLflow client with token authentication support.""" | ||
|
||
def __init__( | ||
self, | ||
tracking_uri: str, | ||
*args, | ||
authentication: bool = False, | ||
check_health: bool = True, | ||
**kwargs, | ||
) -> None: | ||
"""Behaves like a normal `mlflow.MlflowClient` but with token authentication injected on every call. | ||
Parameters | ||
---------- | ||
tracking_uri : str | ||
The URI of the MLflow tracking server. | ||
authentication : bool, optional | ||
Enable token authentication, by default False | ||
check_health : bool, optional | ||
Check the health of the MLflow server on init, by default True | ||
*args : Any | ||
Additional arguments to pass to the MLflow client. | ||
**kwargs : Any | ||
Additional keyword arguments to pass to the MLflow client. | ||
""" | ||
self.anemoi_auth = TokenAuth(tracking_uri, enabled=authentication) | ||
if check_health: | ||
super().__getattribute__("anemoi_auth").authenticate() | ||
health_check(tracking_uri) | ||
super().__init__(tracking_uri, *args, **kwargs) | ||
|
||
def __getattribute__(self, name: str) -> Any: | ||
"""Intercept attribute access and inject authentication.""" | ||
attr = super().__getattribute__(name) | ||
if callable(attr) and name != "anemoi_auth": | ||
super().__getattribute__("anemoi_auth").authenticate() | ||
return attr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
|
||
if TYPE_CHECKING: | ||
import pytest_mock | ||
|
||
from anemoi.training.diagnostics.mlflow.client import AnemoiMlflowClient | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mocks(mocker: pytest_mock.MockerFixture) -> None: | ||
mocker.patch("anemoi.training.diagnostics.mlflow.client.TokenAuth") | ||
mocker.patch("anemoi.training.diagnostics.mlflow.client.health_check") | ||
mocker.patch("anemoi.training.diagnostics.mlflow.client.AnemoiMlflowClient.search_experiments") | ||
|
||
|
||
def test_auth_injected() -> None: | ||
client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=False) | ||
client.search_experiments() | ||
client.search_experiments() | ||
|
||
assert client.anemoi_auth.authenticate.call_count == 2 | ||
|
||
|
||
def test_health_check() -> None: | ||
# the internal health check will trigger an authenticate call | ||
client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=True) | ||
|
||
client.anemoi_auth.authenticate.assert_called_once() |