22Pre/post-processing capability for DMD instances.
33"""
44
5- from typing import Callable , Dict
5+ from __future__ import annotations
66
7- from pydmd .dmdbase import DMDBase
7+ from inspect import isroutine
8+ from typing import Any , Dict , Generic , Tuple , TypeVar
89
10+ import numpy as np
911
10- def _shallow_preprocessing (_ : Dict , * args , ** kwargs ):
11- return args + tuple (kwargs .values ())
12+ from pydmd .dmdbase import DMDBase
1213
14+ # Pre-processing output type
15+ S = TypeVar ("S" )
1316
14- def _shallow_postprocessing (_ : Dict , * args ):
15- # The first item of args is always the output of dmd.reconstructed_data
16- return args [0 ]
1717
18+ class PrePostProcessing (Generic [S ]):
19+ def pre_processing (self , X : np .ndarray ) -> Tuple [S , np .ndarray ]:
20+ return None , X
1821
19- def _tuplify ( value ):
20- if isinstance ( value , tuple ):
21- return value
22- return ( value ,)
22+ def post_processing (
23+ self , pre_processing_output : S , Y : np . ndarray
24+ ) -> np . ndarray :
25+ return Y
2326
2427
25- class PrePostProcessingDMD :
28+ class PrePostProcessingDMD ( Generic [ S ]) :
2629 """
2730 Pre/post-processing decorator. This class is not thread-safe in case of
2831 stateful transformations.
@@ -40,20 +43,14 @@ class PrePostProcessingDMD:
4043 def __init__ (
4144 self ,
4245 dmd : DMDBase ,
43- pre_processing : Callable = _shallow_preprocessing ,
44- post_processing : Callable = _shallow_postprocessing ,
46+ pre_post_processing : PrePostProcessing [S ] = PrePostProcessing (),
4547 ):
4648 if dmd is None :
4749 raise ValueError ("DMD instance cannot be None" )
48- if pre_processing is None :
49- pre_processing = _shallow_preprocessing
50- if post_processing is None :
51- post_processing = _shallow_postprocessing
50+ self ._pre_post_processing = pre_post_processing
5251
53- self ._pre_post_processed_dmd = dmd
54- self ._pre_processing = pre_processing
55- self ._post_processing = post_processing
56- self ._state_holder = None
52+ self ._dmd = dmd
53+ self ._pre_processing_output : S | None = None
5754
5855 def __getattribute__ (self , name ):
5956 try :
@@ -65,18 +62,14 @@ def __getattribute__(self, name):
6562 return self ._pre_processing_fit
6663
6764 if "reconstructed_data" == name :
68- output = self ._post_processing (
69- self ._state_holder ,
70- self ._pre_post_processed_dmd .reconstructed_data ,
71- )
72- return output
65+ return self ._reconstructed_data_with_post_processing ()
7366
7467 # This check is needed to allow copy/deepcopy
75- if name != "_pre_post_processed_dmd " :
76- sub_dmd = self ._pre_post_processed_dmd
68+ if name != "_dmd " :
69+ sub_dmd = self ._dmd
7770 if isinstance (sub_dmd , PrePostProcessingDMD ):
7871 return PrePostProcessingDMD .__getattribute__ (sub_dmd , name )
79- return object .__getattribute__ (self ._pre_post_processed_dmd , name )
72+ return object .__getattribute__ (self ._dmd , name )
8073 return None
8174
8275 @property
@@ -87,19 +80,61 @@ def pre_post_processed_dmd(self):
8780 :return: decorated DMD instance.
8881 :rtype: pydmd.DMDBase
8982 """
90- return self ._pre_post_processed_dmd
83+ return self ._dmd
9184
9285 @property
9386 def modes_activation_bitmask (self ):
94- return self ._pre_post_processed_dmd .modes_activation_bitmask
87+ return self ._dmd .modes_activation_bitmask
9588
9689 @modes_activation_bitmask .setter
9790 def modes_activation_bitmask (self , value ):
98- self ._pre_post_processed_dmd .modes_activation_bitmask = value
91+ self ._dmd .modes_activation_bitmask = value
9992
10093 def _pre_processing_fit (self , * args , ** kwargs ):
101- self ._state_holder = dict ()
102- pre_processing_output = _tuplify (
103- self ._pre_processing (self ._state_holder , * args , ** kwargs )
94+ X = PrePostProcessingDMD ._extract_training_data (* args , ** kwargs )
95+ self ._pre_processing_output , pre_processed_training_data = (
96+ self ._pre_post_processing .pre_processing (X )
97+ )
98+ new_args , new_kwargs = PrePostProcessingDMD ._replace_training_data (
99+ pre_processed_training_data , * args , ** kwargs
104100 )
105- return self ._pre_post_processed_dmd .fit (* pre_processing_output )
101+ return self ._dmd .fit (* new_args , ** new_kwargs )
102+
103+ def _reconstructed_data_with_post_processing (self ) -> np .ndarray :
104+ data = self ._dmd .reconstructed_data
105+
106+ if not isroutine (data ):
107+ return self ._pre_post_processing .post_processing (
108+ self ._pre_processing_output ,
109+ data ,
110+ )
111+
112+ # e.g. DMDc
113+ def output (* args , ** kwargs ) -> np .ndarray :
114+ return self ._pre_post_processing .post_processing (
115+ self ._pre_processing_output ,
116+ data (* args , ** kwargs ),
117+ )
118+
119+ return output
120+
121+ @staticmethod
122+ def _extract_training_data (* args , ** kwargs ):
123+ if len (args ) >= 1 :
124+ return args [0 ]
125+ elif "X" in kwargs :
126+ return kwargs ["X" ]
127+ raise ValueError (
128+ f"Could not extract training data from { args = } , { kwargs = } "
129+ )
130+
131+ @staticmethod
132+ def _replace_training_data (
133+ new_training_data : Any , * args , ** kwargs
134+ ) -> [Tuple [Any , ...], Dict [str , Any ]]:
135+ if len (args ) >= 1 :
136+ return (new_training_data ,) + args [1 :], kwargs
137+ elif "X" in kwargs :
138+ new_kwargs = dict (kwargs )
139+ new_kwargs ["X" ] = new_training_data
140+ return args , new_kwargs
0 commit comments