diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 2098ef50..4bee1e9e 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -44,7 +44,7 @@ def get_mlflow_run_params(config, tracking_uri): if config.training.run_id: import mlflow - if not config.diagnostics.log.mlflow.offline: + if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline: TokenAuth(tracking_uri).authenticate() mlflow_client = mlflow.MlflowClient(tracking_uri) @@ -262,6 +262,7 @@ def __init__( forked: Optional[bool] = False, run_id: Optional[str] = None, offline: Optional[bool] = False, + authentication: Optional[bool] = None, # artifact_location: Optional[str] = None, # avoid passing any artifact location otherwise it would mess up the offline logging of artifacts ) -> None: @@ -278,8 +279,13 @@ def __init__( self._forked = forked if rank_zero_only.rank == 0: - self.auth = TokenAuth(tracking_uri, enabled=not offline) - LOGGER.info(f"Token authentication {'enabled' if not offline else 'disabled'} for {tracking_uri}") + enabled = authentication and not offline + self.auth = TokenAuth(tracking_uri, enabled=enabled) + + if offline: + LOGGER.info("MLflow is logging offline.") + else: + LOGGER.info(f"MLflow token authentication {'enabled' if enabled else 'disabled'} for {tracking_uri}") super().__init__( experiment_name=experiment_name,