diff --git a/CADETProcess/comparison/comparator.py b/CADETProcess/comparison/comparator.py index 4c0a8752f..bef01c989 100644 --- a/CADETProcess/comparison/comparator.py +++ b/CADETProcess/comparison/comparator.py @@ -1,6 +1,7 @@ import copy import importlib import functools +import warnings import numpy as np import matplotlib.pyplot as plt @@ -121,20 +122,19 @@ def labels(self): return labels - @functools.wraps(DifferenceBase.__init__) def add_difference_metric( - self, difference_metric, reference, solution_path, + self, difference_metric, solution_path, reference, *args, **kwargs): """Add a difference metric to the Comparator. Parameters ---------- - difference_metric : str - Name of the difference metric to be evaluated. - reference : str or SolutionBase - Name of the reference or reference itself. + difference_metric : DifferenceMetricBase + Difference metric to be evaluated. solution_path : str Path to the solution in SimulationResults. + reference : str or SolutionBase + Name of the reference or reference itself. *args, **kwargs Additional arguments and keyword arguments to be passed to the difference metric constructor. @@ -144,29 +144,36 @@ def add_difference_metric( CADETProcessError If the difference metric or reference is unknown. """ - try: - module = importlib.import_module( - 'CADETProcess.comparison.difference' + if isinstance(difference_metric, str): + warnings.warn( + 'This method of adding difference metrics is deprecated. ' + 'Instead, pass an instance of the desired metric class.', + DeprecationWarning, stacklevel=2 ) - cls_ = getattr(module, difference_metric) - except KeyError: - raise CADETProcessError("Unknown Metric Type.") + if isinstance(reference, SolutionBase): + reference = reference.name - if isinstance(reference, SolutionBase): - reference = reference.name + if reference not in self.references: + raise CADETProcessError("Unknown Reference.") - if reference not in self.references: - raise CADETProcessError("Unknown Reference.") - - reference = self.references[reference] - - metric = cls_(reference, *args, **kwargs) + reference = self.references[reference] + try: + module = importlib.import_module( + 'CADETProcess.comparison.difference' + ) + cls_ = getattr(module, difference_metric) + difference_metric = cls_(reference, *args, **kwargs) + except KeyError: + raise CADETProcessError("Unknown Difference Metric.") + else: + if not isinstance(difference_metric, DifferenceBase): + raise TypeError("Expected DifferenceBase.") - self.solution_paths[metric] = solution_path + self.solution_paths[difference_metric] = solution_path - self._metrics.append(metric) + self._metrics.append(difference_metric) - return metric + return difference_metric def extract_solution(self, simulation_results, metric): """Extract the solution for a given metric from the SimulationResults object.