Skip to content

Commit c27cc98

Browse files
authored
Add support for mlflow experiment name in auto3dseg (#7442)
Fixes #7441 This PR enable Auto3DSeg users to manage their runs and experiment more efficiently in MLFlow under arbitrary experiment names, by providing experiment name as an input parameter. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). change). - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: Behrooz <[email protected]>
1 parent ec2cc83 commit c27cc98

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

monai/apps/auto3dseg/auto_runner.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class AutoRunner:
8585
can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.
8686
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote
8787
tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None.
88+
mlflow_experiment_name: the name of the experiment in MLflow server.
8889
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
8990
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
9091
@@ -212,6 +213,7 @@ def __init__(
212213
templates_path_or_url: str | None = None,
213214
allow_skip: bool = True,
214215
mlflow_tracking_uri: str | None = None,
216+
mlflow_experiment_name: str | None = None,
215217
**kwargs: Any,
216218
):
217219
if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")):
@@ -253,6 +255,7 @@ def __init__(
253255
self.hpo = hpo and has_nni
254256
self.hpo_backend = hpo_backend
255257
self.mlflow_tracking_uri = mlflow_tracking_uri
258+
self.mlflow_experiment_name = mlflow_experiment_name
256259
self.kwargs = deepcopy(kwargs)
257260

258261
# parse input config for AutoRunner param overrides
@@ -268,7 +271,13 @@ def __init__(
268271
if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):
269272
setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"]
270273

271-
for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config
274+
for param in [
275+
"algos",
276+
"hpo_backend",
277+
"templates_path_or_url",
278+
"mlflow_tracking_uri",
279+
"mlflow_experiment_name",
280+
]: # override from config
272281
if param in self.data_src_cfg:
273282
setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"]
274283

@@ -813,6 +822,7 @@ def run(self):
813822
data_stats_filename=self.datastats_filename,
814823
data_src_cfg_name=self.data_src_cfg_name,
815824
mlflow_tracking_uri=self.mlflow_tracking_uri,
825+
mlflow_experiment_name=self.mlflow_experiment_name,
816826
)
817827

818828
if self.gpu_customization:

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def __init__(self, template_path: PathLike):
8585
self.template_path = template_path
8686
self.data_stats_files = ""
8787
self.data_list_file = ""
88-
self.mlflow_tracking_uri = None
88+
self.mlflow_tracking_uri: str | None = None
89+
self.mlflow_experiment_name: str | None = None
8990
self.output_path = ""
9091
self.name = ""
9192
self.best_metric = None
@@ -139,7 +140,16 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
139140
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
140141
the value is None.
141142
"""
142-
self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore
143+
self.mlflow_tracking_uri = mlflow_tracking_uri
144+
145+
def set_mlflow_experiment_name(self, mlflow_experiment_name: str | None) -> None:
146+
"""
147+
Set the experiment name for MLflow server
148+
149+
Args:
150+
mlflow_experiment_name: a string to specify the experiment name for MLflow server.
151+
"""
152+
self.mlflow_experiment_name = mlflow_experiment_name
143153

144154
def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:
145155
"""
@@ -447,6 +457,7 @@ class BundleGen(AlgoGen):
447457
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
448458
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
449459
the value is None.
460+
mlfow_experiment_name: a string to specify the experiment name for MLflow server.
450461
.. code-block:: bash
451462
452463
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
@@ -460,6 +471,7 @@ def __init__(
460471
data_stats_filename: str | None = None,
461472
data_src_cfg_name: str | None = None,
462473
mlflow_tracking_uri: str | None = None,
474+
mlflow_experiment_name: str | None = None,
463475
):
464476
if algos is None or isinstance(algos, (list, tuple, str)):
465477
if templates_path_or_url is None:
@@ -513,6 +525,7 @@ def __init__(
513525
self.data_stats_filename = data_stats_filename
514526
self.data_src_cfg_name = data_src_cfg_name
515527
self.mlflow_tracking_uri = mlflow_tracking_uri
528+
self.mlflow_experiment_name = mlflow_experiment_name
516529
self.history: list[dict] = []
517530

518531
def set_data_stats(self, data_stats_filename: str) -> None:
@@ -552,10 +565,23 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri):
552565
"""
553566
self.mlflow_tracking_uri = mlflow_tracking_uri
554567

568+
def set_mlflow_experiment_name(self, mlflow_experiment_name):
569+
"""
570+
Set the experiment name for MLflow server
571+
572+
Args:
573+
mlflow_experiment_name: a string to specify the experiment name for MLflow server.
574+
"""
575+
self.mlflow_experiment_name = mlflow_experiment_name
576+
555577
def get_mlflow_tracking_uri(self):
556578
"""Get the tracking URI for MLflow server"""
557579
return self.mlflow_tracking_uri
558580

581+
def get_mlflow_experiment_name(self):
582+
"""Get the experiment name for MLflow server"""
583+
return self.mlflow_experiment_name
584+
559585
def get_history(self) -> list:
560586
"""Get the history of the bundleAlgo object with their names/identifiers"""
561587
return self.history
@@ -608,10 +634,12 @@ def generate(
608634
data_stats = self.get_data_stats()
609635
data_src_cfg = self.get_data_src()
610636
mlflow_tracking_uri = self.get_mlflow_tracking_uri()
637+
mlflow_experiment_name = self.get_mlflow_experiment_name()
611638
gen_algo = deepcopy(algo)
612639
gen_algo.set_data_stats(data_stats)
613640
gen_algo.set_data_source(data_src_cfg)
614641
gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)
642+
gen_algo.set_mlflow_experiment_name(mlflow_experiment_name)
615643
name = f"{gen_algo.name}_{f_id}"
616644

617645
if allow_skip:

monai/apps/auto3dseg/ensemble_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol
464464
ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
465465
)
466466
if self.ensemble_method_name == "AlgoEnsembleBestN":
467-
n_best = kwargs.pop("n_best", False) or 2
467+
n_best = kwargs.pop("n_best", 2)
468468
self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
469469
elif self.ensemble_method_name == "AlgoEnsembleBestByFold":
470470
self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold) # type: ignore

0 commit comments

Comments
 (0)