diff --git a/hsf/augmentation.py b/hsf/augmentation.py index c3c75e1..7c8a693 100644 --- a/hsf/augmentation.py +++ b/hsf/augmentation.py @@ -62,5 +62,4 @@ def get_augmented_subject(subject: tio.Subject, augmentation_cfg: DictConfig, subjects.append(subject) return subjects - else: - return [subject] + return [subject] diff --git a/hsf/engines.py b/hsf/engines.py index 809ab49..df599b2 100644 --- a/hsf/engines.py +++ b/hsf/engines.py @@ -30,15 +30,14 @@ def deepsparse_support() -> str: if vnni: # Optimal for int8 quantized NN return "full" - elif avx512: + if avx512: # AVX512 vector instruction set for fast NN inference return "partial" - elif avx2: + if avx2: # AVX2 vector instruction set slower than AVX512 return "minimal" - else: - # No AVX2/512 or VNNI -> Risk of slow inference time - return "not supported" + # No AVX2/512 or VNNI -> Risk of slow inference time + return "not supported" def print_deepsparse_support(): @@ -107,7 +106,7 @@ def __call__(self, x): assert len(feed_names) == 1, "Only one input is supported" return self.engine.run(None, {feed_names[0]: x}) - elif self.engine_name == "deepsparse": + if self.engine_name == "deepsparse": return self.engine.run([x]) def set_deepsparse_engine(self, model: PosixPath): @@ -133,7 +132,7 @@ def set_ort_engine(self, model): def _correct_provider(provider): if isinstance(provider, str): return provider - elif isinstance(provider, ListConfig): + if isinstance(provider, ListConfig): provider = list(provider) assert len(provider) == 2 provider[1] = dict(provider[1]) diff --git a/hsf/fetch_models.py b/hsf/fetch_models.py index 8f6f8e6..bbc89b7 100644 --- a/hsf/fetch_models.py +++ b/hsf/fetch_models.py @@ -50,9 +50,8 @@ def fetch(directory: str, filename: str, url: str, xxh3_64: str) -> None: if get_hash(str(outfile)) == xxh3_64: log.info(f"{filename} already exists and is up to date") return - else: - log.info(f"{filename} already exists but is not up to date") - outfile.unlink() + log.info(f"{filename} already exists but is not up to date") + outfile.unlink() log.info(f"Fetching {url}") wget.download(url, out=str(outfile)) diff --git a/hsf/multispectrality.py b/hsf/multispectrality.py index 91c1764..7bfe95f 100644 --- a/hsf/multispectrality.py +++ b/hsf/multispectrality.py @@ -38,8 +38,7 @@ def get_second_contrast(mri: PosixPath, pattern: str) -> Optional[PosixPath]: ) == 1, f"Invalid file pattern: {pattern}. No or multiple files found." return second_contrast[0] - else: - return None + return None @handle_cache @@ -61,33 +60,32 @@ def register(mri: PosixPath, """ if cfg.multispectrality.same_space: return second_contrast - else: - registration_params = dict(cfg.multispectrality.registration) - if not registration_params.get("outprefix"): - registration_params["outprefix"] = outprefix + registration_params = dict(cfg.multispectrality.registration) + if not registration_params.get("outprefix"): + registration_params["outprefix"] = outprefix - fixed = ants.image_read(str(mri)) - moving = ants.image_read(str(second_contrast)) + fixed = ants.image_read(str(mri)) + moving = ants.image_read(str(second_contrast)) - log.info(f"Registering {str(second_contrast)} to {str(mri)}") - transformation = ants.registration(fixed=fixed, - moving=moving, - **registration_params) + log.info(f"Registering {str(second_contrast)} to {str(mri)}") + transformation = ants.registration(fixed=fixed, + moving=moving, + **registration_params) - registered = ants.apply_transforms( - fixed=fixed, - moving=moving, - transformlist=transformation["fwdtransforms"]) + registered = ants.apply_transforms( + fixed=fixed, + moving=moving, + transformlist=transformation["fwdtransforms"]) - extensions = "".join(second_contrast.suffixes) - fname = second_contrast.name.replace(extensions, - "") + "_registered.nii.gz" - output_dir = mri.parent / cfg.files.output_dir - output_dir.mkdir(parents=True, exist_ok=True) + extensions = "".join(second_contrast.suffixes) + fname = second_contrast.name.replace(extensions, + "") + "_registered.nii.gz" + output_dir = mri.parent / cfg.files.output_dir + output_dir.mkdir(parents=True, exist_ok=True) - ants.image_write(registered, str(output_dir / fname)) + ants.image_write(registered, str(output_dir / fname)) - return output_dir / fname + return output_dir / fname def get_additional_hippocampi(mri: PosixPath, second_contrast: PosixPath, diff --git a/hsf/segment.py b/hsf/segment.py index f1f2fb6..fab830c 100644 --- a/hsf/segment.py +++ b/hsf/segment.py @@ -53,17 +53,17 @@ def to_ca_mode(logits: torch.Tensor, ca_mode: str = "1/2/3") -> torch.Tensor: _in = torch.sum(logits[:, 1:, :, :, :], dim=1, keepdim=True) return torch.cat([_pre, _in], dim=1) - elif ca_mode == "1/2/3": + if ca_mode == "1/2/3": # identity return logits - elif ca_mode == "1/23": + if ca_mode == "1/23": # ca1; ca2+ca3 _pre = logits[:, :3, :, :, :] _in = logits[:, 3:4, :, :, :] + logits[:, 4:5, :, :, :] _post = logits[:, 5:, :, :, :] return torch.cat([_pre, _in, _post], dim=1) - elif ca_mode == "123": + if ca_mode == "123": # ca1+ca2+ca3 _pre = logits[:, :2, :, :, :] _in = logits[:, @@ -73,10 +73,9 @@ def to_ca_mode(logits: torch.Tensor, ca_mode: str = "1/2/3") -> torch.Tensor: _post = logits[:, 5:, :, :, :] return torch.cat([_pre, _in, _post], dim=1) - else: - raise ValueError( - f"Unknown `ca_mode` ({ca_mode}). `ca_mode` must be 1/2/3, 1/23 or 123" - ) + raise ValueError( + f"Unknown `ca_mode` ({ca_mode}). `ca_mode` must be 1/2/3, 1/23 or 123" + ) def predict(mris: list,