Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IPOPT Optimizer #111

Draft
wants to merge 20 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions CADETProcess/comparison/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def metrics(self):
return self._metrics

@property
def n_diffference_metrics(self):
def n_difference_metrics(self):
"""int: Number of difference metrics in the Comparator."""
return len(self.metrics)

Expand Down Expand Up @@ -242,11 +242,11 @@ def setup_comparison_figure(
tuple
A tuple of the comparison figure(s) and axes object(s).
"""
if self.n_diffference_metrics == 0:
if self.n_difference_metrics == 0:
return (None, None)

comparison_fig_all, comparison_axs_all = plotting.setup_figure(
n_rows=self.n_diffference_metrics,
n_rows=self.n_difference_metrics,
squeeze=False
)

Expand All @@ -255,7 +255,7 @@ def setup_comparison_figure(

comparison_fig_ind: list[Figure] = []
comparison_axs_ind: list[Axes] = []
for i in range(self.n_diffference_metrics):
for i in range(self.n_difference_metrics):
fig, axs = plt.subplots()
comparison_fig_ind.append(fig)
comparison_axs_ind.append(axs)
Expand Down
118 changes: 107 additions & 11 deletions CADETProcess/comparison/difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from CADETProcess import CADETProcessError
from CADETProcess.dataStructure import UnsignedInteger
from CADETProcess.solution import SolutionBase, slice_solution
from CADETProcess.solution import SolutionIO, slice_solution
from CADETProcess.metric import MetricBase
from CADETProcess.reference import ReferenceIO, FractionationReference
from .shape import pearson, pearson_offset
from .peaks import find_peaks, find_breakthroughs

Expand All @@ -24,6 +25,7 @@
'Shape',
'PeakHeight', 'PeakPosition',
'BreakthroughHeight', 'BreakthroughPosition',
'FractionationSSE',
]


Expand Down Expand Up @@ -74,7 +76,7 @@ class DifferenceBase(MetricBase):

Parameters
----------
reference : ReferenceIO
reference : ReferenceBase
Reference used for calculating difference metric.
components : {str, list}, optional
Solution components to be considered.
Expand All @@ -97,6 +99,8 @@ class DifferenceBase(MetricBase):
If True, normalize data. The default is False.
"""

_valid_references = ()

def __init__(
self,
reference,
Expand All @@ -106,14 +110,15 @@ def __init__(
start=None,
end=None,
transform=None,
only_transforms_array=True,
resample=True,
smooth=False,
normalize=False):
"""Initialize an instance of DifferenceBase.

Parameters
----------
reference : ReferenceIO
reference : ReferenceBase
Reference used for calculating difference metric.
components : {str, list}, optional
Solution components to be considered.
Expand All @@ -128,6 +133,8 @@ def __init__(
End time of solution slice to be considerd. The default is None.
transform : callable, optional
Function to transform solution. The default is None.
only_transforms_array: bool, optional
If True, only transform np array of solution object. The default is True.
resample : bool, optional
If True, resample data. The default is True.
smooth : bool, optional
Expand All @@ -143,6 +150,7 @@ def __init__(
self.start = start
self.end = end
self.transform = transform
self.only_transforms_array = only_transforms_array
self.resample = resample
self.smooth = smooth
self.normalize = normalize
Expand All @@ -165,8 +173,11 @@ def reference(self):

@reference.setter
def reference(self, reference):
if not isinstance(reference, SolutionBase):
raise TypeError("Expected SolutionBase")
if not isinstance(reference, self._valid_references):
raise TypeError(
f"Invalid reference type: {type(reference)}. "
f"Expected types: {self._valid_references}."
)

self._reference = copy.deepcopy(reference)
if self.resample and not self._reference.is_resampled:
Expand Down Expand Up @@ -221,11 +232,12 @@ def resamples_smoothes_and_normalizes_solution(func):
@wraps(func)
def wrapper(self, solution, *args, **kwargs):
solution = copy.deepcopy(solution)
solution.resample(
self._reference.time[0],
self._reference.time[-1],
len(self._reference.time),
)
if self.resample:
solution.resample(
self._reference.time[0],
self._reference.time[-1],
len(self._reference.time),
)
if self.normalize and not solution.is_normalized:
solution.normalize()
if self.smooth and not solution.is_smoothed:
Expand All @@ -241,7 +253,11 @@ def transforms_solution(func):
def wrapper(self, solution, *args, **kwargs):
if self.transform is not None:
solution = copy.deepcopy(solution)
solution.solution = self.transform(solution.solution)

if self.only_transforms_array:
solution.solution = self.transform(solution.solution)
else:
solution = self.transform(solution)

value = func(self, solution, *args, **kwargs)
return value
Expand Down Expand Up @@ -321,6 +337,8 @@ def calculate_sse(simulation, reference):
class SSE(DifferenceBase):
"""Sum of squared errors (SSE) difference metric."""

_valid_references = (ReferenceIO, SolutionIO)

def _evaluate(self, solution):
sse = calculate_sse(solution.solution, self.reference.solution)

Expand Down Expand Up @@ -348,6 +366,8 @@ def calculate_rmse(simulation, reference):
class RMSE(DifferenceBase):
"""Root mean squared errors (RMSE) difference metric."""

_valid_references = (SolutionIO, ReferenceIO)

def _evaluate(self, solution):
rmse = calculate_rmse(solution.solution, self.reference.solution)

Expand All @@ -357,6 +377,8 @@ def _evaluate(self, solution):
class NRMSE(DifferenceBase):
"""Normalized root mean squared errors (RRMSE) difference metric."""

_valid_references = (SolutionIO, ReferenceIO)

def _evaluate(self, solution):
rmse = calculate_rmse(solution.solution, self.reference.solution)
nrmse = rmse / np.max(self.reference.solution, axis=0)
Expand All @@ -373,6 +395,8 @@ class Norm(DifferenceBase):
The order of the norm.
"""

_valid_references = (SolutionIO, ReferenceIO)

order = UnsignedInteger()

def _evaluate(self, solution):
Expand All @@ -398,6 +422,8 @@ class L2(Norm):
class AbsoluteArea(DifferenceBase):
"""Absolute difference in area difference metric."""

_valid_references = (SolutionIO, ReferenceIO)

def _evaluate(self, solution):
"""np.array: Absolute difference in area compared to reference.

Expand All @@ -418,6 +444,8 @@ def _evaluate(self, solution):
class RelativeArea(DifferenceBase):
"""Relative difference in area difference metric."""

_valid_references = (SolutionIO, ReferenceIO)

def _evaluate(self, solution):
"""np.array: Relative difference in area compared to reference.

Expand Down Expand Up @@ -462,6 +490,8 @@ class Shape(DifferenceBase):

"""

_valid_references = (SolutionIO, ReferenceIO)

@wraps(DifferenceBase.__init__)
def __init__(
self, *args,
Expand Down Expand Up @@ -645,6 +675,8 @@ class PeakHeight(DifferenceBase):
Contains the normalization factors for each peak in each component.
"""

_valid_references = (SolutionIO, ReferenceIO)

@wraps(DifferenceBase.__init__)
def __init__(
self, *args,
Expand Down Expand Up @@ -737,6 +769,8 @@ class PeakPosition(DifferenceBase):
Contains the normalization factors for each peak in each component.
"""

_valid_references = (SolutionIO, ReferenceIO)

@wraps(DifferenceBase.__init__)
def __init__(self, *args, normalize_metrics=True, normalization_factor=None, **kwargs):
"""Initialize PeakPosition object.
Expand Down Expand Up @@ -823,6 +857,8 @@ class BreakthroughHeight(DifferenceBase):

"""

_valid_references = (SolutionIO, ReferenceIO)

@wraps(DifferenceBase.__init__)
def __init__(self, *args, normalize_metrics=True, **kwargs):
"""Initialize BreakthroughHeight metric.
Expand Down Expand Up @@ -874,6 +910,8 @@ def _evaluate(self, solution):
class BreakthroughPosition(DifferenceBase):
"""Absolute difference in breakthrough curve position difference metric."""

_valid_references = (SolutionIO, ReferenceIO)

@wraps(DifferenceBase.__init__)
def __init__(self, *args, normalize_metrics=True, normalization_factor=None, **kwargs):
"""
Expand Down Expand Up @@ -935,3 +973,61 @@ def _evaluate(self, solution):
]

return np.abs(score)


class FractionationSSE(DifferenceBase):
"""Fractionation based score using SSE."""

_valid_references = (FractionationReference)

@wraps(DifferenceBase.__init__)
def __init__(self, *args, normalize_metrics=True, normalization_factor=None, **kwargs):
"""
Initialize the FractionationSSE object.

Parameters
----------
*args :
Positional arguments for DifferenceBase.
normalize_metrics : bool, optional
Whether to normalize the metrics. Default is True.
normalization_factor : float, optional
Factor to use for normalization.
If None, it is set to the maximum of the difference between the reference
breakthrough and the start time, and the difference between the end time and
the reference breakthrough.
**kwargs : dict
Keyword arguments passed to the base class constructor.

"""
super().__init__(*args, resample=False, only_transforms_array=False, **kwargs)

if not isinstance(self.reference, FractionationReference):
raise TypeError("FractionationSSE can only work with FractionationReference")

def transform(solution):
solution = copy.deepcopy(solution)
solution_fractions = [
solution.create_fraction(frac.start, frac.end)
for frac in self.reference.fractions
]

solution.time = np.array([(frac.start + frac.end)/2 for frac in solution_fractions])
solution.solution = np.array([frac.concentration for frac in solution_fractions])

return solution

self.transform = transform

def _evaluate(self, solution):
"""np.array: Difference in breakthrough position (time).

Parameters
----------
solution : SolutionIO
Concentration profile of simulation.

"""
sse = calculate_sse(solution.solution, self.reference.solution)

return sse
14 changes: 14 additions & 0 deletions CADETProcess/dataStructure/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def name(self):


class CachedPropertiesMixin(Structure):
"""
Mixin class for caching properties in a structured object.

This class is designed to be used as a mixin in conjunction with other classes
inheriting from `Structure`. It provides functionality for caching properties and
managing a lock state to control the caching behavior.

Notes
-----
- To prevent the return of outdated state, the cache is cleared whenever the `lock`
state is changed.
"""

_lock = Bool(default=False)

def __init__(self, *args, **kwargs):
Expand All @@ -31,6 +44,7 @@ def __init__(self, *args, **kwargs):

@property
def lock(self):
"""bool: If True, properties are cached. False otherwise."""
return self._lock

@lock.setter
Expand Down
Loading
Loading