diff --git a/src/signature_mahalanobis_knn/sig_mahal_knn.py b/src/signature_mahalanobis_knn/sig_mahal_knn.py index 38c6dbc..5e4d97c 100644 --- a/src/signature_mahalanobis_knn/sig_mahal_knn.py +++ b/src/signature_mahalanobis_knn/sig_mahal_knn.py @@ -106,17 +106,25 @@ def fit( ) # set default kwargs for signature transformer if not provided - if signature_kwargs is None or signature_kwargs == {}: - signature_kwargs = { - "augmentation_list": ("addtime",), - "window_name": "global", - "window_depth": None, - "window_length": None, - "window_step": None, - "rescaling": None, - "sig_tfm": "signature", - "depth": 2, - } + sig_defaults = { + "augmentation_list": ("addtime",), + "window_name": "global", + "window_depth": None, + "window_length": None, + "window_step": None, + "rescaling": None, + "sig_tfm": "signature", + "depth": 2, + } + + if signature_kwargs is None: + # set all defaults + signature_kwargs = sig_defaults + else: + # set defaults for any missing kwargs + for key, value in sig_defaults.items(): + if key not in signature_kwargs: + signature_kwargs[key] = value self.signature_transform = SignatureTransformer( **signature_kwargs,