Skip to content

Commit

Permalink
Added tests for reading configuration file
Browse files Browse the repository at this point in the history
  • Loading branch information
giladmaya committed Mar 13, 2024
1 parent 3daa2b4 commit 2f79c33
Show file tree
Hide file tree
Showing 16 changed files with 430 additions and 9 deletions.
101 changes: 100 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ PyYAML = "^6.0"
scikit-learn = "^1.0.2"
xgboost = "^1.6.1"
pandas = "^1.4.3"
pydantic = "^2.6.4"
pydantic-yaml = "^1.2.1"

[tool.poetry.dev-dependencies]
pytest = "7.1.2"
pytest-cov = "^4.1.0"
pytest-subtests = "^0.12.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
File renamed without changes.
99 changes: 99 additions & 0 deletions src/pd_dwi/config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from enum import Enum
from typing import List, Dict, Any, Optional, Set

from pydantic import BaseModel, ConfigDict, Field, PositiveInt, NonNegativeInt, root_validator, model_validator


class Labels(BaseModel):
    """Names of the two class labels used by a dataset."""

    # Both label names are required and must be non-empty strings.
    negative: str = Field(min_length=1)
    positive: str = Field(min_length=1)


class TimePoint(Enum):
    """Time point identifiers accepted in a configuration.

    Each member's value is the literal string expected in the
    configuration file.
    """

    T0 = 'T0'
    T1 = 'T1'
    T2 = 'T2'


class Modality(Enum):
    """Image modality names accepted in a configuration.

    Each member's value is the literal string expected in the
    configuration file (e.g. ``image: ADC 100600800``).
    """

    ADC0100 = 'ADC 0100'
    ADC0100600800 = 'ADC 0100600800'
    ADC100600800 = 'ADC 100600800'
    F = 'F'


class Mask(Enum):
    """Mask names accepted in a configuration.

    The member's value is the literal string expected in the
    configuration file (e.g. ``mask: DWI MASK``).
    """

    DWI = 'DWI MASK'


class Dataset(BaseModel):
    """The ``dataset`` section of a model configuration."""

    # Names of the negative / positive classes.
    labels: Labels
    # At least one time point and one modality must be listed.
    # ``min_length`` is the pydantic v2 spelling of the collection-size
    # constraint; the v1 keyword ``min_items`` is deprecated and warns.
    time_points: Set[TimePoint] = Field(min_length=1)
    modalities: Set[Modality] = Field(min_length=1)
    # Masks available in the dataset; an empty set is accepted.
    masks: Set[Mask]


class RadiomicsFeaturesEncoder(BaseModel):
    """A single radiomics encoder entry: one image, one mask, some time points."""

    image: Modality
    mask: Mask
    # At least one time point is required. ``min_length`` is the pydantic v2
    # spelling of the collection-size constraint; the v1 keyword
    # ``min_items`` is deprecated and emits a warning.
    time_points: Set[TimePoint] = Field(min_length=1)


class RadiomicsFeaturesTransformer(BaseModel):
    """The ``radiomics`` section: a list of encoders plus engine settings."""

    # Encoder entries, in the order they are listed in the configuration.
    encoders: List[RadiomicsFeaturesEncoder]
    # Free-form mapping; its contents are not validated by this model.
    engine: Dict[str, Any]


class FeatureTransformer(BaseModel):
    """The ``features_transformer`` section of the pipeline configuration."""

    radiomics: RadiomicsFeaturesTransformer


class FeatureSelection(BaseModel):
    """The ``feature_selection`` section of the pipeline configuration."""

    # presumably the number of features to keep - confirm against the
    # code that consumes this configuration
    k: int


class Classifier(BaseModel):
    """The ``classifier`` section of the pipeline configuration."""

    # presumably identifies the classifier implementation to use - confirm
    # against the code that consumes this configuration
    module: str
    # Free-form mapping; its contents are not validated by this model.
    parameters: Dict[str, Any]


class Pipeline(BaseModel):
    """The ``pipeline`` section: feature transformer, feature selection, classifier."""

    features_transformer: FeatureTransformer
    feature_selection: FeatureSelection
    classifier: Classifier


class GridSearchParamGrid(BaseModel):
    """Parameter grids for the grid search, one optional mapping per pipeline step.

    Both fields may be omitted or set to ``null``; they are annotated
    ``Optional`` so the type matches the ``None`` default and an explicit
    ``null`` in the configuration validates instead of being rejected.
    """

    # Free-form mappings; their contents are not validated by this model.
    classifier: Optional[Dict[str, Any]] = None
    feature_selection: Optional[Dict[str, Any]] = None


class GridSearch(BaseModel):
    """The ``grid_search_cv`` section of a model configuration."""

    # Required; must be zero or greater.
    verbose: NonNegativeInt
    # Optional settings with their defaults.
    scoring: str = 'roc_auc'
    cv: PositiveInt = 5
    # Required parameter grids.
    param_grid: GridSearchParamGrid


class ModelConfig(BaseModel):
    """Top-level schema of a model configuration file.

    Instances are immutable (``frozen``) and unknown top-level keys are
    rejected (``extra='forbid'``). After field validation, the encoders are
    cross-checked against the dataset: every encoder image must be one of the
    dataset modalities and every encoder mask one of the dataset masks.
    """

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(frozen=True, extra='forbid')

    dataset: Dataset
    pipeline: Pipeline
    # The grid-search section is optional; ``None`` when it is absent.
    grid_search_cv: Optional[GridSearch] = None

    @model_validator(mode='after')
    def validate_encoders_dataset(self):
        """Ensure the encoders only reference modalities and masks the dataset has.

        :return: the validated model instance
        :raises ValueError: if an encoder uses a modality or a mask that is
            not listed in the dataset section
        """
        encoders = self.pipeline.features_transformer.radiomics.encoders

        modalities = {encoder.image for encoder in encoders}
        if not modalities.issubset(self.dataset.modalities):
            raise ValueError("Encoders contain modalities that are not available in dataset")

        masks = {encoder.mask for encoder in encoders}
        if not masks.issubset(self.dataset.masks):
            raise ValueError("Encoders contain masks that are not available in dataset")

        # NOTE(review): encoder time points are not cross-checked against
        # dataset.time_points here - confirm whether that is intentional.
        return self
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dataset:

pipeline:
features_transformer:
radiomics:
encoders: # encoder names format is {time point}_{image}
- image: F
mask: DWI MASK
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dataset:

pipeline:
features_transformer:
radiomics:
encoders: # encoder names format is {time point}_{image}
- image: ADC 100600800
mask: DWI MASK
Expand Down
File renamed without changes.
15 changes: 15 additions & 0 deletions src/pd_dwi/config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from yaml import load, FullLoader

from pd_dwi.config.config import ModelConfig


def read_config(config):
    """Load a model configuration and validate it against ``ModelConfig``.

    :param config: one of

        * an open file-like object (anything with a ``read`` attribute)
          containing YAML,
        * a path (``str`` or ``os.PathLike``) ending in ``.yaml`` or ``.yml``,
        * an already parsed ``dict``.
    :return: the parsed configuration as plain Python data (the validated
        ``ModelConfig`` instance itself is not returned)
    :raises NotImplementedError: if ``config`` is none of the supported inputs
    :raises pydantic.ValidationError: if the configuration does not match
        ``ModelConfig``
    """
    from os import PathLike, fsdecode

    # NOTE(review): SafeLoader would be the more conservative choice if
    # configuration files can come from untrusted sources.
    if hasattr(config, 'read'):
        parsed = load(config, Loader=FullLoader)
    elif isinstance(config, dict):
        # Already parsed (e.g. built programmatically); only validate it.
        parsed = config
    elif isinstance(config, (str, PathLike)) and fsdecode(config).endswith(('.yaml', '.yml')):
        # Context manager so the file handle is closed even if parsing fails.
        with open(config) as stream:
            parsed = load(stream, Loader=FullLoader)
    else:
        raise NotImplementedError()

    # Validation only; raises on failure, the model instance is discarded.
    ModelConfig.model_validate(parsed)
    return parsed
13 changes: 5 additions & 8 deletions src/pd_dwi/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from pathlib import Path
from pickle import dump, load as pkl_load, HIGHEST_PROTOCOL

import pandas as pd
from jsonschema.validators import validate
from pydantic_yaml import parse_yaml_raw_as
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GridSearchCV
from yaml import FullLoader, load

from pd_dwi.config.config import ModelConfig
from pd_dwi.config.utils import read_config
from pd_dwi.dataset import create_dataset, validate_dataset
from pd_dwi.training_utils import create_model_from_config

Expand All @@ -17,14 +21,7 @@ def __init__(self, config=None, model_obj=None):

@classmethod
def from_config(cls, config):
    """Create an instance from a configuration.

    Loading and validation are delegated to ``read_config``; whatever it
    returns is passed to the constructor as ``config``.

    :param config: any input accepted by ``read_config``
    :return: a new instance of ``cls``
    """
    return cls(config=read_config(config))

def save(self, path):
assert self.model is not None
Expand Down
Loading

0 comments on commit 2f79c33

Please sign in to comment.