From 92ad881b2a32aa95076e6aa322377e50ffcc60ee Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 28 May 2024 15:14:50 +0000 Subject: [PATCH] fix: argument order and explicit types --- src/anemoi/models/preprocessing/imputer.py | 11 +++++++---- src/anemoi/models/preprocessing/normalizer.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index 34ad667..a2e5bd6 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -13,6 +13,7 @@ import torch +from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.preprocessing import BasePreprocessor LOGGER = logging.getLogger(__name__) @@ -24,8 +25,8 @@ class BaseImputer(BasePreprocessor, ABC): def __init__( self, config=None, + data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, - data_indices: Optional[dict] = None, ) -> None: """Initialize the imputer. @@ -176,7 +177,7 @@ def __init__( statistics: Optional[dict] = None, data_indices: Optional[dict] = None, ) -> None: - super().__init__(config, statistics, data_indices) + super().__init__(config, data_indices, statistics) self._create_imputation_indices(statistics) @@ -199,8 +200,10 @@ class ConstantImputer(BaseImputer): ``` """ - def __init__(self, config=None, statistics: Optional[dict] = None, data_indices: Optional[dict] = None) -> None: - super().__init__(config, statistics, data_indices) + def __init__( + self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None + ) -> None: + super().__init__(config, data_indices, statistics) self._create_imputation_indices() diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py index 6326cb1..8a7dd61 100644 --- a/src/anemoi/models/preprocessing/normalizer.py +++ b/src/anemoi/models/preprocessing/normalizer.py @@ -14,6 +14,7 @@ import numpy as np import torch +from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.preprocessing import BasePreprocessor LOGGER = logging.getLogger(__name__) @@ -25,7 +26,7 @@ class InputNormalizer(BasePreprocessor): def __init__( self, config=None, - data_indices: Optional[dict] = None, + data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, ) -> None: """Initialize the normalizer.