Skip to content

Commit

Permalink
examples: fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jtsextonMITRE committed Jan 3, 2025
1 parent fc1a8a9 commit 4ddb7b7
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _log_distance_metrics(distance_metrics_: Dict[str, List[List[float]]]) -> No
del distance_metrics_["image"]
del distance_metrics_["label"]
for metric_name, metric_values_list in distance_metrics_.items():
metric_values = np.array(metric_values_list)
metric_values = np.array(metric_values_list).astype('float64')
post_metrics(metric_name=f"{metric_name}_mean", metric_value=metric_values.mean())
post_metrics(metric_name=f"{metric_name}_median", metric_value=np.median(metric_values))
post_metrics(metric_name=f"{metric_name}_stdev", metric_value=metric_values.std())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _log_distance_metrics(distance_metrics_: Dict[str, List[List[float]]]) -> No
del distance_metrics_["image"]
del distance_metrics_["label"]
for metric_name, metric_values_list in distance_metrics_.items():
metric_values = np.array(metric_values_list)
metric_values = np.array(metric_values_list).astype('float64')
post_metrics(metric_name=f"{metric_name}_mean", metric_value=metric_values.mean())
post_metrics(metric_name=f"{metric_name}_median", metric_value=np.median(metric_values))
post_metrics(metric_name=f"{metric_name}_stdev", metric_value=metric_values.std())
Expand Down
23 changes: 0 additions & 23 deletions examples/task-plugins/dioptra_custom/fgm_mnist_demo/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,6 @@ def add_model_to_registry(name: str, model_dir: str) -> Optional[ModelVersion]:
return model_version


@pyplugs.register
def get_experiment_name() -> str:
"""Gets the name of the experiment for the current run.
Args:
active_run: The :py:class:`mlflow.ActiveRun` object managing the current run's
state.
Returns:
The name of the experiment.
"""
active_run = mlflow.active_run()

experiment_name: str = (
MlflowClient().get_experiment(active_run.info.experiment_id).name
)
LOGGER.info(
"Obtained experiment name of active run", experiment_name=experiment_name
)

return experiment_name


@pyplugs.register
def prepend_cwd(path: str) -> Path:
ret = Path.cwd() / path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,6 @@ def add_model_to_registry(
return model_version


@pyplugs.register
def get_experiment_name(active_run: MlflowRun) -> str:
"""Gets the name of the experiment for the current run.
Args:
active_run: The :py:class:`mlflow.ActiveRun` object managing the current run's
state.
Returns:
The name of the experiment.
"""
experiment_name: str = (
MlflowClient().get_experiment(active_run.info.experiment_id).name
)
LOGGER.info(
"Obtained experiment name of active run", experiment_name=experiment_name
)

return experiment_name


@pyplugs.register
@require_package("tensorflow", exc_type=TensorflowDependencyError)
def load_tensorflow_keras_classifier(uri: str) -> Sequential:
Expand Down

0 comments on commit 4ddb7b7

Please sign in to comment.