From 98292b12b00d9c36dc515061d4b7fd08f96236a1 Mon Sep 17 00:00:00 2001 From: Gert Mertes <13658335+gmertes@users.noreply.github.com> Date: Tue, 15 Oct 2024 16:31:51 +0100 Subject: [PATCH] Add `AnemoiMlflowClient` with auth support (#86) * feat: anemoi mlflow client with authentication * fix: recursion on anemoi_auth * chore: add tests * chore: changelog --- CHANGELOG.md | 1 + .../training/diagnostics/mlflow/client.py | 56 +++++++++++++++++++ tests/diagnostics/mlflow/test_client.py | 38 +++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 src/anemoi/training/diagnostics/mlflow/client.py create mode 100644 tests/diagnostics/mlflow/test_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fec321aa..3b233a77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ Keep it human-readable, your future self will thank you! - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) - Feature: Authentication support for mlflow sync - [#51](https://github.com/ecmwf/anemoi-training/pull/51) - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) +- Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86) - Long Rollout Plots ### Fixed diff --git a/src/anemoi/training/diagnostics/mlflow/client.py b/src/anemoi/training/diagnostics/mlflow/client.py new file mode 100644 index 00000000..5c49f929 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/client.py @@ -0,0 +1,56 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from typing import Any + +from mlflow import MlflowClient + +from anemoi.training.diagnostics.mlflow.auth import TokenAuth +from anemoi.training.diagnostics.mlflow.utils import health_check + + +class AnemoiMlflowClient(MlflowClient): + """Anemoi extension of the MLflow client with token authentication support.""" + + def __init__( + self, + tracking_uri: str, + *args, + authentication: bool = False, + check_health: bool = True, + **kwargs, + ) -> None: + """Behaves like a normal `mlflow.MlflowClient` but with token authentication injected on every call. + + Parameters + ---------- + tracking_uri : str + The URI of the MLflow tracking server. + authentication : bool, optional + Enable token authentication, by default False + check_health : bool, optional + Check the health of the MLflow server on init, by default True + *args : Any + Additional arguments to pass to the MLflow client. + **kwargs : Any + Additional keyword arguments to pass to the MLflow client. + + """ + self.anemoi_auth = TokenAuth(tracking_uri, enabled=authentication) + if check_health: + super().__getattribute__("anemoi_auth").authenticate() + health_check(tracking_uri) + super().__init__(tracking_uri, *args, **kwargs) + + def __getattribute__(self, name: str) -> Any: + """Intercept attribute access and inject authentication.""" + attr = super().__getattribute__(name) + if callable(attr) and name != "anemoi_auth": + super().__getattribute__("anemoi_auth").authenticate() + return attr diff --git a/tests/diagnostics/mlflow/test_client.py b/tests/diagnostics/mlflow/test_client.py new file mode 100644 index 00000000..93c0aeeb --- /dev/null +++ b/tests/diagnostics/mlflow/test_client.py @@ -0,0 +1,38 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + import pytest_mock + +from anemoi.training.diagnostics.mlflow.client import AnemoiMlflowClient + + +@pytest.fixture(autouse=True) +def mocks(mocker: pytest_mock.MockerFixture) -> None: + mocker.patch("anemoi.training.diagnostics.mlflow.client.TokenAuth") + mocker.patch("anemoi.training.diagnostics.mlflow.client.health_check") + mocker.patch("anemoi.training.diagnostics.mlflow.client.AnemoiMlflowClient.search_experiments") + + +def test_auth_injected() -> None: + client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=False) + client.search_experiments() + client.search_experiments() + + assert client.anemoi_auth.authenticate.call_count == 2 + + +def test_health_check() -> None: + # the internal health check will trigger an authenticate call + client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=True) + + client.anemoi_auth.authenticate.assert_called_once()