Skip to content

Commit

Permalink
Merge pull request #32 from clementpoiret/deepsource-autofix-9e5ccf0f
Browse files Browse the repository at this point in the history
refactor: remove assert statement from non-test files
  • Loading branch information
clementpoiret authored Mar 29, 2024
2 parents 3dcd9b0 + 5ac4253 commit 5fb1d84
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions hsf/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
log = logging.getLogger(__name__)



def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple:
"""
Get left and right hippocampi from a given MRI.
Expand Down Expand Up @@ -63,8 +62,8 @@ def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple:
original_mri_path=mri)


def predict(mri: PosixPath, second_mri: Optional[PosixPath], engines: Generator,
cfg: DictConfig) -> tuple:
def predict(mri: PosixPath, second_mri: Optional[PosixPath],
engines: Generator, cfg: DictConfig) -> tuple:
"""
Predict the hippocampal segmentation for a given MRI.
Expand Down Expand Up @@ -156,7 +155,8 @@ def filter_mris(mris: List[PosixPath], overwrite: bool) -> List[PosixPath]:
def _get_segmentations(mri: PosixPath) -> List[PosixPath]:
extensions = "".join(mri.suffixes)
stem = mri.name.replace(extensions, "")
segmentations = list(mri.parent.glob(f"{stem}*_hippocampus_seg.nii.gz"))
segmentations = list(
mri.parent.glob(f"{stem}*_hippocampus_seg.nii.gz"))

if len(segmentations) > 2:
log.warning(
Expand Down Expand Up @@ -189,9 +189,10 @@ def main(cfg: DictConfig) -> None:
bs = cfg.hardware.engine_settings.batch_size
multispectral = 2 if cfg.multispectrality.pattern else 1

assert multispectral * (
tta + 1
) % bs == 0, "test_time_num_aug+1 must be a multiple of batch_size for deepsparse"
if multispectral * (tta + 1) % bs != 0:
raise AssertionError(
"test_time_num_aug+1 must be a multiple of batch_size for deepsparse"
)

mris = load_from_config(cfg.files.path, cfg.files.pattern)
_n = len(mris)
Expand Down Expand Up @@ -229,7 +230,8 @@ def main(cfg: DictConfig) -> None:
else:
additional_hippocampi = [None, None]

for j, hippocampus in enumerate(zip(hippocampi, additional_hippocampi)):
for j, hippocampus in enumerate(zip(hippocampi,
additional_hippocampi)):
engines = get_inference_engines(
cfg.segmentation.models_path,
engine_name=cfg.hardware.engine,
Expand Down

0 comments on commit 5fb1d84

Please sign in to comment.