Skip to content

Commit

Permalink
[WIP] Restructure Comparator
Browse files Browse the repository at this point in the history
  • Loading branch information
schmoelder committed Mar 23, 2024
1 parent 841f188 commit dffff3f
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions CADETProcess/comparison/comparator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import importlib
import functools
import warnings

import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -121,20 +122,19 @@ def labels(self):

return labels

@functools.wraps(DifferenceBase.__init__)
def add_difference_metric(
self, difference_metric, reference, solution_path,
self, difference_metric, solution_path, reference,
*args, **kwargs):
"""Add a difference metric to the Comparator.
Parameters
----------
difference_metric : str
Name of the difference metric to be evaluated.
reference : str or SolutionBase
Name of the reference or reference itself.
difference_metric : DifferenceMetricBase
Difference metric to be evaluated.
solution_path : str
Path to the solution in SimulationResults.
reference : str or SolutionBase
Name of the reference or reference itself.
*args, **kwargs
Additional arguments and keyword arguments to be passed to the
difference metric constructor.
Expand All @@ -144,29 +144,36 @@ def add_difference_metric(
CADETProcessError
If the difference metric or reference is unknown.
"""
try:
module = importlib.import_module(
'CADETProcess.comparison.difference'
if isinstance(difference_metric, str):
warnings.warn(
'This method of adding difference metrics is deprecated. '
'Instead, pass an instance of the desired metric class.',
DeprecationWarning, stacklevel=2
)
cls_ = getattr(module, difference_metric)
except KeyError:
raise CADETProcessError("Unknown Metric Type.")
if isinstance(reference, SolutionBase):
reference = reference.name

if isinstance(reference, SolutionBase):
reference = reference.name
if reference not in self.references:
raise CADETProcessError("Unknown Reference.")

if reference not in self.references:
raise CADETProcessError("Unknown Reference.")

reference = self.references[reference]

metric = cls_(reference, *args, **kwargs)
reference = self.references[reference]
try:
module = importlib.import_module(
'CADETProcess.comparison.difference'
)
cls_ = getattr(module, difference_metric)
difference_metric = cls_(reference, *args, **kwargs)
except KeyError:
raise CADETProcessError("Unknown Difference Metric.")
else:
if not isinstance(difference_metric, DifferenceBase):
raise TypeError("Expected DifferenceBase.")

self.solution_paths[metric] = solution_path
self.solution_paths[difference_metric] = solution_path

self._metrics.append(metric)
self._metrics.append(difference_metric)

return metric
return difference_metric

def extract_solution(self, simulation_results, metric):
"""Extract the solution for a given metric from the SimulationResults object.
Expand Down

0 comments on commit dffff3f

Please sign in to comment.