Skip to content

Commit

Permalink
add ability to continue run in mlflow logs and not create child run
Browse files Browse the repository at this point in the history
* add ability to continue run in mlflow logs and not create child run

add model init logic for weights only and all

bugfix: commented out synchronous arg in MLFLOW logger

fixed overwriting function with hidden property in AIFSMLflowLogger

* Update logging.py

Simplifying the if block for setting log_hyperparams

* removed synchronous arg from config, refined code

* Update logged message

* removed synchronous param from AIFSMLflowLogger

* Added plot async param back

* change default setting for on_resume_create_child to False to maintain default behaviour from before this PR

---------

Co-authored-by: [email protected] <[email protected]>
(cherry picked from commit ecmwf-lab/aifs-mono@b856ddd)
  • Loading branch information
gmertes committed Jul 18, 2024
1 parent 4756457 commit 488e903
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/anemoi/training/diagnostics/mlflow/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,23 @@ def get_mlflow_run_params(config, tracking_uri):
if config.training.run_id or config.training.fork_run_id:
"Either run_id or fork_run_id must be provided to resume a run."

if config.training.run_id:
import mlflow
import mlflow

if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline:
TokenAuth(tracking_uri).authenticate()
if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline:
TokenAuth(tracking_uri).authenticate()

mlflow_client = mlflow.MlflowClient(tracking_uri)
mlflow_client = mlflow.MlflowClient(tracking_uri)

if config.training.run_id and config.diagnostics.log.mlflow.on_resume_create_child:
parent_run_id = config.training.run_id # parent_run_id
run_name = mlflow_client.get_run(parent_run_id).info.run_name
tags["mlflow.parentRunId"] = parent_run_id
tags["resumedRun"] = "True" # tags can't take boolean values
elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child:
run_id = config.training.run_id
run_name = mlflow_client.get_run(run_id).info.run_name
mlflow_client.update_run(run_id=run_id, status="RUNNING")
tags["resumedRun"] = "True"
else:
parent_run_id = config.training.fork_run_id
tags["forkedRun"] = "True"
Expand Down Expand Up @@ -282,6 +287,7 @@ def __init__(
run_id: Optional[str] = None,
offline: Optional[bool] = False,
authentication: Optional[bool] = None,
log_hyperparams: Optional[bool] = True,
# artifact_location: Optional[str] = None,
# avoid passing any artifact location otherwise it would mess up the offline logging of artifacts
) -> None:
Expand All @@ -296,6 +302,7 @@ def __init__(

self._resumed = resumed
self._forked = forked
self._flag_log_hparams = log_hyperparams

if rank_zero_only.rank == 0:
enabled = authentication and not offline
Expand Down Expand Up @@ -336,9 +343,8 @@ def log_system_metrics(self) -> None:
self.run_id,
resume_logging=self.run_id is not None,
)
global run_id_to_system_metrics_monitor
run_id_to_system_metrics_monitor = {}
run_id_to_system_metrics_monitor[self.run_id] = system_monitor
self.run_id_to_system_metrics_monitor = {}
self.run_id_to_system_metrics_monitor[self.run_id] = system_monitor
system_monitor.start()

@rank_zero_only
Expand All @@ -353,9 +359,8 @@ def log_terminal_output(self, artifact_save_dir="") -> None:
self.experiment,
self.run_id,
)
global run_id_to_log_monitor
run_id_to_log_monitor = {}
run_id_to_log_monitor[self.run_id] = log_monitor
self.run_id_to_log_monitor = {}
self.run_id_to_log_monitor[self.run_id] = log_monitor
log_monitor.start()

def _clean_params(self, params):
Expand All @@ -373,28 +378,29 @@ def _clean_params(self, params):
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
"""Overwrite the log_hyperparams method to flatten config params using '.'."""
params = _convert_params(params)
params = _flatten_dict(params, delimiter=".") # Flatten dict with '.' to not break API queries
params = self._clean_params(params)
if self._flag_log_hparams:
params = _convert_params(params)
params = _flatten_dict(params, delimiter=".") # Flatten dict with '.' to not break API queries
params = self._clean_params(params)

from mlflow.entities import Param
from mlflow.entities import Param

# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]

for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

@rank_zero_only
def finalize(self, status: str = "success") -> None:
# save the last obtained refresh token to disk
self.auth.save()

# finalize logging and system metrics monitor
if run_id_to_system_metrics_monitor:
run_id_to_system_metrics_monitor[self.run_id].finish()
if run_id_to_log_monitor:
run_id_to_log_monitor[self.run_id].finish(status)
if getattr(self, "run_id_to_system_metrics_monitor", None):
self.run_id_to_system_metrics_monitor[self.run_id].finish()
if getattr(self, "run_id_to_log_monitor", None):
self.run_id_to_log_monitor[self.run_id].finish(status)

super().finalize(status)

0 comments on commit 488e903

Please sign in to comment.