Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Add AnemoiMlflowClient with auth support (#86)
Browse files Browse the repository at this point in the history
* feat: anemoi mlflow client with authentication

* fix: recursion on anemoi_auth

* chore: add tests

* chore: changelog
  • Loading branch information
gmertes authored Oct 15, 2024
1 parent 39c309d commit 98292b1
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Keep it human-readable, your future self will thank you!
- Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50)
- Feature: Authentication support for mlflow sync - [#51](https://github.com/ecmwf/anemoi-training/pull/51)
- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)
- Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86)
- Long Rollout Plots

### Fixed
Expand Down
56 changes: 56 additions & 0 deletions src/anemoi/training/diagnostics/mlflow/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

from typing import Any

from mlflow import MlflowClient

from anemoi.training.diagnostics.mlflow.auth import TokenAuth
from anemoi.training.diagnostics.mlflow.utils import health_check


class AnemoiMlflowClient(MlflowClient):
"""Anemoi extension of the MLflow client with token authentication support."""

def __init__(
self,
tracking_uri: str,
*args,
authentication: bool = False,
check_health: bool = True,
**kwargs,
) -> None:
"""Behaves like a normal `mlflow.MlflowClient` but with token authentication injected on every call.
Parameters
----------
tracking_uri : str
The URI of the MLflow tracking server.
authentication : bool, optional
Enable token authentication, by default False
check_health : bool, optional
Check the health of the MLflow server on init, by default True
*args : Any
Additional arguments to pass to the MLflow client.
**kwargs : Any
Additional keyword arguments to pass to the MLflow client.
"""
self.anemoi_auth = TokenAuth(tracking_uri, enabled=authentication)
if check_health:
super().__getattribute__("anemoi_auth").authenticate()
health_check(tracking_uri)
super().__init__(tracking_uri, *args, **kwargs)

def __getattribute__(self, name: str) -> Any:
"""Intercept attribute access and inject authentication."""
attr = super().__getattribute__(name)
if callable(attr) and name != "anemoi_auth":
super().__getattribute__("anemoi_auth").authenticate()
return attr
38 changes: 38 additions & 0 deletions tests/diagnostics/mlflow/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
import pytest_mock

from anemoi.training.diagnostics.mlflow.client import AnemoiMlflowClient


@pytest.fixture(autouse=True)
def mocks(mocker: pytest_mock.MockerFixture) -> None:
mocker.patch("anemoi.training.diagnostics.mlflow.client.TokenAuth")
mocker.patch("anemoi.training.diagnostics.mlflow.client.health_check")
mocker.patch("anemoi.training.diagnostics.mlflow.client.AnemoiMlflowClient.search_experiments")


def test_auth_injected() -> None:
client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=False)
client.search_experiments()
client.search_experiments()

assert client.anemoi_auth.authenticate.call_count == 2


def test_health_check() -> None:
# the internal health check will trigger an authenticate call
client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=True)

client.anemoi_auth.authenticate.assert_called_once()

0 comments on commit 98292b1

Please sign in to comment.