From 8f0faf63109e361b75614ca6ed05d48d930e22d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Fri, 19 Jan 2024 14:47:44 +0100 Subject: [PATCH] Add FractionationSSE --- CADETProcess/comparison/difference.py | 61 ++++++++++++++++++++++++++- tests/test_difference.py | 48 +++++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/CADETProcess/comparison/difference.py b/CADETProcess/comparison/difference.py index e7e055d9..f19feea0 100644 --- a/CADETProcess/comparison/difference.py +++ b/CADETProcess/comparison/difference.py @@ -11,7 +11,7 @@ from CADETProcess.dataStructure import UnsignedInteger from CADETProcess.solution import SolutionIO, slice_solution from CADETProcess.metric import MetricBase -from CADETProcess.reference import ReferenceIO +from CADETProcess.reference import ReferenceIO, FractionationReference from .shape import pearson, pearson_offset from .peaks import find_peaks, find_breakthroughs @@ -25,6 +25,7 @@ 'Shape', 'PeakHeight', 'PeakPosition', 'BreakthroughHeight', 'BreakthroughPosition', + 'FractionationSSE', ] @@ -972,3 +973,61 @@ def _evaluate(self, solution): ] return np.abs(score) + + +class FractionationSSE(DifferenceBase): + """Fractionation based score using SSE.""" + + _valid_references = (FractionationReference) + + @wraps(DifferenceBase.__init__) + def __init__(self, *args, normalize_metrics=True, normalization_factor=None, **kwargs): + """ + Initialize the FractionationSSE object. + + Parameters + ---------- + *args : + Positional arguments for DifferenceBase. + normalize_metrics : bool, optional + Whether to normalize the metrics. Default is True. + normalization_factor : float, optional + Factor to use for normalization. + If None, it is set to the maximum of the difference between the reference + breakthrough and the start time, and the difference between the end time and + the reference breakthrough. + **kwargs : dict + Keyword arguments passed to the base class constructor. + + """ + super().__init__(*args, resample=False, only_transforms_array=False, **kwargs) + + if not isinstance(self.reference, FractionationReference): + raise TypeError("FractionationSSE can only work with FractionationReference") + + def transform(solution): + solution = copy.deepcopy(solution) + solution_fractions = [ + solution.create_fraction(frac.start, frac.end) + for frac in self.reference.fractions + ] + + solution.time = np.array([(frac.start + frac.end)/2 for frac in solution_fractions]) + solution.solution = np.array([frac.concentration for frac in solution_fractions]) + + return solution + + self.transform = transform + + def _evaluate(self, solution): + """np.array: Difference in breakthrough position (time). + + Parameters + ---------- + solution : SolutionIO + Concentration profile of simulation. + + """ + sse = calculate_sse(solution.solution, self.reference.solution) + + return sse diff --git a/tests/test_difference.py b/tests/test_difference.py index fe7e79ae..6a61b6d6 100644 --- a/tests/test_difference.py +++ b/tests/test_difference.py @@ -6,6 +6,7 @@ from CADETProcess import CADETProcessError from CADETProcess.processModel import ComponentSystem from CADETProcess.reference import ReferenceIO +from CADETProcess.solution import SolutionIO comp_2 = ComponentSystem(['A', 'B']) @@ -34,6 +35,8 @@ solution_2_gaussian_different_height[:, 1] = stats.norm.pdf(time, mu_0, sigma_0) solution_2_gaussian_different_height[:, 0] = stats.norm.pdf(time, mu_2, sigma_2) +q_const = np.ones(time.shape) + from CADETProcess.comparison import SSE class TestSSE(unittest.TestCase): @@ -387,5 +390,50 @@ def test_metric(self): metrics = difference.evaluate(self.reference) +from CADETProcess.fractionation import Fraction +from CADETProcess.reference import FractionationReference +from CADETProcess.comparison import FractionationSSE + + +class TestFractionation(unittest.TestCase): + def __init__(self, methodName='runTest'): + super().__init__(methodName) + + def setUp(self): + fraction_1 = Fraction( + start=15, + end=30, + mass=[0.49865015, 0.02274985], + volume=15, + ) + fraction_2 = Fraction( + start=30, + end=45, + mass=[0.49865015, 0.81859462], + volume=15, + ) + self.fractions = [fraction_1, fraction_2] + + component_system = ComponentSystem(['A', 'B']) + self.reference = FractionationReference( + 'fractions', [fraction_1, fraction_2], + component_system=component_system + ) + + self.solution = SolutionIO( + 'simple', comp_2, time, solution_2_gaussian, flow_rate=q_const + ) + + def test_metric(self): + # Compare with itself + difference = FractionationSSE( + self.reference, + components=['A'], + ) + metrics_expected = [1.30315857e-19] + metrics = difference.evaluate(self.solution) + np.testing.assert_almost_equal(metrics, metrics_expected) + + if __name__ == '__main__': unittest.main()