Skip to content

Commit

Permalink
style: format code with isort and Yapf
Browse files Browse the repository at this point in the history
This commit fixes the style issues introduced in 540b33f according to the output
from isort and Yapf.

Details: #32
  • Loading branch information
deepsource-autofix[bot] authored Mar 29, 2024
1 parent 540b33f commit 5ac4253
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions hsf/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
log = logging.getLogger(__name__)



def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple:
"""
Get left and right hippocampi from a given MRI.
Expand Down Expand Up @@ -63,8 +62,8 @@ def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple:
original_mri_path=mri)


def predict(mri: PosixPath, second_mri: Optional[PosixPath], engines: Generator,
cfg: DictConfig) -> tuple:
def predict(mri: PosixPath, second_mri: Optional[PosixPath],
engines: Generator, cfg: DictConfig) -> tuple:
"""
Predict the hippocampal segmentation for a given MRI.
Expand Down Expand Up @@ -156,7 +155,8 @@ def filter_mris(mris: List[PosixPath], overwrite: bool) -> List[PosixPath]:
def _get_segmentations(mri: PosixPath) -> List[PosixPath]:
extensions = "".join(mri.suffixes)
stem = mri.name.replace(extensions, "")
segmentations = list(mri.parent.glob(f"{stem}*_hippocampus_seg.nii.gz"))
segmentations = list(
mri.parent.glob(f"{stem}*_hippocampus_seg.nii.gz"))

if len(segmentations) > 2:
log.warning(
Expand Down Expand Up @@ -189,10 +189,10 @@ def main(cfg: DictConfig) -> None:
bs = cfg.hardware.engine_settings.batch_size
multispectral = 2 if cfg.multispectrality.pattern else 1

if multispectral * (
tta + 1
) % bs != 0:
raise AssertionError("test_time_num_aug+1 must be a multiple of batch_size for deepsparse")
if multispectral * (tta + 1) % bs != 0:
raise AssertionError(
"test_time_num_aug+1 must be a multiple of batch_size for deepsparse"
)

mris = load_from_config(cfg.files.path, cfg.files.pattern)
_n = len(mris)
Expand Down Expand Up @@ -230,7 +230,8 @@ def main(cfg: DictConfig) -> None:
else:
additional_hippocampi = [None, None]

for j, hippocampus in enumerate(zip(hippocampi, additional_hippocampi)):
for j, hippocampus in enumerate(zip(hippocampi,
additional_hippocampi)):
engines = get_inference_engines(
cfg.segmentation.models_path,
engine_name=cfg.hardware.engine,
Expand Down

0 comments on commit 5ac4253

Please sign in to comment.