diff --git a/src/pd_dwi/config/config.py b/src/pd_dwi/config/config.py index 12da7a9..9ccbba8 100644 --- a/src/pd_dwi/config/config.py +++ b/src/pd_dwi/config/config.py @@ -96,4 +96,11 @@ def validate_encoders_dataset(self): if not masks.issubset(self.dataset.masks): raise ValueError("Encoders contain masks that are not available in dataset") + time_points = self.pipeline.features_transformer.radiomics.encoders[0].time_points + for e in self.pipeline.features_transformer.radiomics.encoders[1:]: + time_points = time_points.union(e.time_points) + + if not time_points.issubset(self.dataset.time_points): + raise ValueError("Encoders contain time points that are not available in dataset") + return self diff --git a/tests/data/invalid_encoders_bad_time_points.yaml b/tests/data/invalid_encoders_bad_time_points.yaml new file mode 100644 index 0000000..cbef0c5 --- /dev/null +++ b/tests/data/invalid_encoders_bad_time_points.yaml @@ -0,0 +1,57 @@ +dataset: + labels: + negative: Non-pCR + positive: pCR + + time_points: + - T0 + - T2 + + modalities: + - ADC 0100 + - F + + masks: + - DWI MASK + +pipeline: + features_transformer: + radiomics: + encoders: # encoder names format is {time point}_{image} + - image: ADC 0100 + mask: DWI MASK + time_points: + - T0 + - T1 + - T2 + - image: F + mask: DWI MASK + time_points: + - T0 + - T1 + engine: + imageType: + Original: { } + setting: + resampledPixelSpacing: [ 1.27,1.27,4.0 ] + feature_selection: + k: 100 + classifier: + module: samples.XGBClassifier + parameters: + random_state: 42 + use_label_encoder: False + validate_parameters: True + learning_rate: 0.01 + n_estimators: 1000 + max_depth: 20 + min_child_weight: 10 + +grid_search_cv: + verbose: 3 + scoring: roc_auc + cv: 5 + param_grid: + classifier: + scale_pos_weight: + - 2.22 # balanced \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py index 253d137..e76ed14 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -20,7 +20,8 @@ def test_from_config_valid(): 'config_name,validation_msg', [ ('invalid_no_modalities', 'Field required'), - ('invalid_encoders_bad_modality', 'Value error, Encoders contain modalities that are not available in dataset') + ('invalid_encoders_bad_modality', 'Value error, Encoders contain modalities that are not available in dataset'), + ('invalid_encoders_bad_time_points', 'Value error, Encoders contain time points that are not available in dataset') ] ) def test_from_config_invalid(config_name, validation_msg):