Skip to content

Commit 1f6db1f

Browse files
authored
Refactor pre/post processing implementation (PyDMD#554)
Refactor pre/post processing implementation
1 parent 9984eda commit 1f6db1f

File tree

10 files changed

+297
-258
lines changed

10 files changed

+297
-258
lines changed

pydmd/bopdmd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ def _initialize_alpha(self):
13691369
# rank-truncated SVD of Y.
13701370
Y = (ux1 + ux2) / 2
13711371
# Element-wise division by time differences. w/o large T
1372-
Z = (ux2 - ux1) / (self._time[1:] - self._time[:-1])
1372+
Z = (ux2 - ux1) / (self._time[1:] - self._time[:-1])
13731373
U, s, V = compute_svd(Y, self._svd_rank)
13741374
S = np.diag(s)
13751375

@@ -1712,16 +1712,16 @@ def plot_eig_uq(
17121712

17131713
if flip_axes:
17141714
eigs = self.eigs.imag + 1j * self.eigs.real
1715-
plt.xlabel("$Im(\omega)$")
1716-
plt.ylabel("$Re(\omega)$")
1715+
plt.xlabel(r"$Im(\omega)$")
1716+
plt.ylabel(r"$Re(\omega)$")
17171717

17181718
if eigs_true is not None:
17191719
eigs_true = eigs_true.imag + 1j * eigs_true.real
17201720

17211721
else:
17221722
eigs = self.eigs
1723-
plt.xlabel("$Re(\omega)$")
1724-
plt.ylabel("$Im(\omega)$")
1723+
plt.xlabel(r"$Re(\omega)$")
1724+
plt.ylabel(r"$Im(\omega)$")
17251725

17261726
for e, std in zip(eigs, self.eigenvalues_std):
17271727
# Plot 2 standard deviations.

pydmd/preprocessing/hankel.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,32 @@
22
Hankel pre-processing.
33
"""
44

5-
from functools import partial
6-
from typing import Dict, List, Tuple, Union
5+
from __future__ import annotations
6+
7+
import sys
8+
from typing import Tuple, Union
79

810
import numpy as np
911

1012
from pydmd.dmdbase import DMDBase
11-
from pydmd.preprocessing.pre_post_processing import PrePostProcessingDMD
13+
from pydmd.preprocessing.pre_post_processing import (
14+
PrePostProcessing,
15+
PrePostProcessingDMD,
16+
)
1217
from pydmd.utils import pseudo_hankel_matrix
1318

14-
_reconstruction_method_type = Union[str, np.ndarray, List, Tuple]
19+
if sys.version_info >= (3, 12):
20+
from typing import override
21+
else:
22+
from typing_extensions import override
23+
24+
_ReconstructionMethodType = Union[str, np.ndarray, list, tuple]
1525

1626

1727
def hankel_preprocessing(
1828
dmd: DMDBase,
1929
d: int,
20-
reconstruction_method: _reconstruction_method_type = "first",
30+
reconstruction_method: _ReconstructionMethodType = "first",
2131
):
2232
"""
2333
Hankel pre-processing.
@@ -26,19 +36,13 @@ def hankel_preprocessing(
2636
:param d: Hankel matrix rank
2737
:param reconstruction_method: Reconstruction method.
2838
"""
29-
return PrePostProcessingDMD(
30-
dmd,
31-
partial(_preprocessing, d=d),
32-
partial(
33-
_hankel_post_processing,
34-
d=d,
35-
reconstruction_method=reconstruction_method,
36-
),
37-
)
39+
if isinstance(reconstruction_method, (tuple, list)):
40+
reconstruction_method = np.asarray(reconstruction_method)
3841

39-
40-
def _preprocessing(_: Dict, X: np.ndarray, d: int, **kwargs):
41-
return (pseudo_hankel_matrix(X, d),) + tuple(kwargs.values())
42+
pre_post_processing = _HankelPrePostProcessing(
43+
d=d, reconstruction_method=reconstruction_method
44+
)
45+
return PrePostProcessingDMD(dmd, pre_post_processing)
4246

4347

4448
def _reconstructions(rec: np.ndarray, d: int):
@@ -92,24 +96,34 @@ def _first_reconstructions(reconstructions: np.ndarray, d: int) -> np.ndarray:
9296
return reconstructions[..., time_idxes, d_idxes, :].swapaxes(-1, -2)
9397

9498

95-
def _hankel_post_processing(
96-
_: Dict, # No state
97-
X: np.ndarray,
98-
d: int,
99-
reconstruction_method: _reconstruction_method_type,
100-
) -> np.ndarray:
101-
rec = _reconstructions(X, d=d)
102-
rec = np.ma.array(rec, mask=np.isnan(rec))
103-
104-
if reconstruction_method == "first":
105-
result = _first_reconstructions(rec, d=d)
106-
elif reconstruction_method == "mean":
107-
result = np.nanmean(rec, axis=1).T
108-
elif isinstance(reconstruction_method, (np.ndarray, list, tuple)):
109-
result = np.ma.average(rec, axis=1, weights=reconstruction_method).T
110-
else:
111-
raise ValueError(
112-
f"The reconstruction method wasn't recognized: {reconstruction_method}"
113-
)
114-
115-
return result.filled(fill_value=0)
99+
class _HankelPrePostProcessing(PrePostProcessing):
    """
    Hankel pre/post-processing step.

    Pre-processing replaces the training data with
    ``pseudo_hankel_matrix(X, d)``. Post-processing turns the output of the
    wrapped DMD back into a single array, combining the candidate
    reconstructions according to ``reconstruction_method``.

    :param d: value forwarded to ``pseudo_hankel_matrix`` and to the
        reconstruction helpers.
    :param reconstruction_method: ``"first"``, ``"mean"`` or a
        ``numpy.ndarray`` of weights (forwarded to ``np.ma.average``).
        Anything else makes :meth:`post_processing` raise ``ValueError``.
    """

    def __init__(
        self, *, d: int, reconstruction_method: _ReconstructionMethodType
    ):
        self._d = d
        self._reconstruction_method = reconstruction_method

    @override
    def pre_processing(self, X: np.ndarray) -> Tuple[None, np.ndarray]:
        """
        Build the pseudo-Hankel matrix of the training data.

        :param X: training data.
        :return: ``(None, pseudo_hankel_matrix(X, d))``; no state needs to be
            carried over to the post-processing step.
        """
        return None, pseudo_hankel_matrix(X, self._d)

    @override
    def post_processing(
        self, pre_processing_output: None, Y: np.ndarray
    ) -> np.ndarray:
        """
        Combine the reconstructions derived from ``Y`` into a single array.

        :param pre_processing_output: unused (pre-processing is stateless).
        :param Y: output of the wrapped DMD instance.
        :return: the combined reconstruction, with masked entries replaced
            by 0.
        :raises ValueError: if the reconstruction method is not recognized.
        """
        rec = _reconstructions(Y, d=self._d)
        # NaN entries are masked so that they do not contribute to the
        # averages computed below.
        rec = np.ma.array(rec, mask=np.isnan(rec))

        method = self._reconstruction_method
        # A single flat chain guarantees that *every* unsupported value,
        # including unknown strings, ends up in the ValueError branch instead
        # of leaving `result` unbound.
        if isinstance(method, np.ndarray):
            result = np.ma.average(rec, axis=1, weights=method).T
        elif isinstance(method, str) and method == "first":
            result = _first_reconstructions(rec, d=self._d)
        elif isinstance(method, str) and method == "mean":
            result = np.nanmean(rec, axis=1).T
        else:
            raise ValueError(f"{self._reconstruction_method=} not recognized")

        return result.filled(fill_value=0)

pydmd/preprocessing/pre_post_processing.py

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,30 @@
22
Pre/post-processing capability for DMD instances.
33
"""
44

5-
from typing import Callable, Dict
5+
from __future__ import annotations
66

7-
from pydmd.dmdbase import DMDBase
7+
from inspect import isroutine
8+
from typing import Any, Dict, Generic, Tuple, TypeVar
89

10+
import numpy as np
911

10-
def _shallow_preprocessing(_: Dict, *args, **kwargs):
11-
return args + tuple(kwargs.values())
12+
from pydmd.dmdbase import DMDBase
1213

14+
# Pre-processing output type
15+
S = TypeVar("S")
1316

14-
def _shallow_postprocessing(_: Dict, *args):
15-
# The first item of args is always the output of dmd.reconstructed_data
16-
return args[0]
1717

18+
class PrePostProcessing(Generic[S]):
19+
def pre_processing(self, X: np.ndarray) -> Tuple[S, np.ndarray]:
20+
return None, X
1821

19-
def _tuplify(value):
20-
if isinstance(value, tuple):
21-
return value
22-
return (value,)
22+
def post_processing(
23+
self, pre_processing_output: S, Y: np.ndarray
24+
) -> np.ndarray:
25+
return Y
2326

2427

25-
class PrePostProcessingDMD:
28+
class PrePostProcessingDMD(Generic[S]):
2629
"""
2730
Pre/post-processing decorator. This class is not thread-safe in case of
2831
stateful transformations.
@@ -40,20 +43,14 @@ class PrePostProcessingDMD:
4043
def __init__(
4144
self,
4245
dmd: DMDBase,
43-
pre_processing: Callable = _shallow_preprocessing,
44-
post_processing: Callable = _shallow_postprocessing,
46+
pre_post_processing: PrePostProcessing[S] = PrePostProcessing(),
4547
):
4648
if dmd is None:
4749
raise ValueError("DMD instance cannot be None")
48-
if pre_processing is None:
49-
pre_processing = _shallow_preprocessing
50-
if post_processing is None:
51-
post_processing = _shallow_postprocessing
50+
self._pre_post_processing = pre_post_processing
5251

53-
self._pre_post_processed_dmd = dmd
54-
self._pre_processing = pre_processing
55-
self._post_processing = post_processing
56-
self._state_holder = None
52+
self._dmd = dmd
53+
self._pre_processing_output: S | None = None
5754

5855
def __getattribute__(self, name):
5956
try:
@@ -65,18 +62,14 @@ def __getattribute__(self, name):
6562
return self._pre_processing_fit
6663

6764
if "reconstructed_data" == name:
68-
output = self._post_processing(
69-
self._state_holder,
70-
self._pre_post_processed_dmd.reconstructed_data,
71-
)
72-
return output
65+
return self._reconstructed_data_with_post_processing()
7366

7467
# This check is needed to allow copy/deepcopy
75-
if name != "_pre_post_processed_dmd":
76-
sub_dmd = self._pre_post_processed_dmd
68+
if name != "_dmd":
69+
sub_dmd = self._dmd
7770
if isinstance(sub_dmd, PrePostProcessingDMD):
7871
return PrePostProcessingDMD.__getattribute__(sub_dmd, name)
79-
return object.__getattribute__(self._pre_post_processed_dmd, name)
72+
return object.__getattribute__(self._dmd, name)
8073
return None
8174

8275
@property
@@ -87,19 +80,61 @@ def pre_post_processed_dmd(self):
8780
:return: decorated DMD instance.
8881
:rtype: pydmd.DMDBase
8982
"""
90-
return self._pre_post_processed_dmd
83+
return self._dmd
9184

9285
@property
9386
def modes_activation_bitmask(self):
94-
return self._pre_post_processed_dmd.modes_activation_bitmask
87+
return self._dmd.modes_activation_bitmask
9588

9689
@modes_activation_bitmask.setter
9790
def modes_activation_bitmask(self, value):
98-
self._pre_post_processed_dmd.modes_activation_bitmask = value
91+
self._dmd.modes_activation_bitmask = value
9992

10093
def _pre_processing_fit(self, *args, **kwargs):
101-
self._state_holder = dict()
102-
pre_processing_output = _tuplify(
103-
self._pre_processing(self._state_holder, *args, **kwargs)
94+
X = PrePostProcessingDMD._extract_training_data(*args, **kwargs)
95+
self._pre_processing_output, pre_processed_training_data = (
96+
self._pre_post_processing.pre_processing(X)
97+
)
98+
new_args, new_kwargs = PrePostProcessingDMD._replace_training_data(
99+
pre_processed_training_data, *args, **kwargs
104100
)
105-
return self._pre_post_processed_dmd.fit(*pre_processing_output)
101+
return self._dmd.fit(*new_args, **new_kwargs)
102+
103+
def _reconstructed_data_with_post_processing(self) -> np.ndarray:
104+
data = self._dmd.reconstructed_data
105+
106+
if not isroutine(data):
107+
return self._pre_post_processing.post_processing(
108+
self._pre_processing_output,
109+
data,
110+
)
111+
112+
# e.g. DMDc
113+
def output(*args, **kwargs) -> np.ndarray:
114+
return self._pre_post_processing.post_processing(
115+
self._pre_processing_output,
116+
data(*args, **kwargs),
117+
)
118+
119+
return output
120+
121+
@staticmethod
122+
def _extract_training_data(*args, **kwargs):
123+
if len(args) >= 1:
124+
return args[0]
125+
elif "X" in kwargs:
126+
return kwargs["X"]
127+
raise ValueError(
128+
f"Could not extract training data from {args=}, {kwargs=}"
129+
)
130+
131+
@staticmethod
132+
def _replace_training_data(
133+
new_training_data: Any, *args, **kwargs
134+
) -> [Tuple[Any, ...], Dict[str, Any]]:
135+
if len(args) >= 1:
136+
return (new_training_data,) + args[1:], kwargs
137+
elif "X" in kwargs:
138+
new_kwargs = dict(kwargs)
139+
new_kwargs["X"] = new_training_data
140+
return args, new_kwargs

0 commit comments

Comments
 (0)