diff --git a/src/scportrait/pipeline/_base.py b/src/scportrait/pipeline/_base.py
index fe07217..d241cdb 100644
--- a/src/scportrait/pipeline/_base.py
+++ b/src/scportrait/pipeline/_base.py
@@ -11,6 +11,7 @@
 
 from scportrait.pipeline._utils.helper import read_config
 
+
 class Logable:
     """Create log entries.
 
@@ -179,13 +180,13 @@ class ProcessingStep(Logable):
     def __init__(
         self,
         config,
-        directory = None,
-        project_location = None,
+        directory=None,
+        project_location=None,
         debug=False,
         overwrite=False,
         project=None,
         filehandler=None,
-        from_project:bool = False,
+        from_project: bool = False,
     ):
         super().__init__(directory=directory)
 
@@ -205,7 +206,7 @@ def __init__(
         if isinstance(config, str):
             config = read_config(config)
             if self.__class__.__name__ in config.keys():
-                self.config = config[self.__class__.__name__ ]
+                self.config = config[self.__class__.__name__]
             else:
                 self.config = config
         else:
diff --git a/src/scportrait/pipeline/_utils/helper.py b/src/scportrait/pipeline/_utils/helper.py
index 1f6c012..e930104 100644
--- a/src/scportrait/pipeline/_utils/helper.py
+++ b/src/scportrait/pipeline/_utils/helper.py
@@ -1,8 +1,10 @@
 from typing import TypeVar
+
 import yaml
 
 T = TypeVar("T")
 
+
 def read_config(config_path: str) -> dict:
     with open(config_path) as stream:
         try:
@@ -11,6 +13,7 @@ def read_config(config_path: str) -> dict:
             print(exc)
     return config
 
+
 def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
     """Flatten a list of lists into a single list.
 
diff --git a/src/scportrait/pipeline/featurization.py b/src/scportrait/pipeline/featurization.py
index 3208959..28124a7 100644
--- a/src/scportrait/pipeline/featurization.py
+++ b/src/scportrait/pipeline/featurization.py
@@ -5,8 +5,8 @@
 from contextlib import redirect_stdout
 from functools import partial as func_partial
 
-import numpy as np
 import h5py
+import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
 import torch
@@ -18,6 +18,7 @@
 from scportrait.tools.ml.datasets import HDF5SingleCellDataset
 from scportrait.tools.ml.plmodels import MultilabelSupervisedModel
 
+
 class _FeaturizationBase(ProcessingStep):
     PRETRAINED_MODEL_NAMES = [
         "autophagy_classifier",
@@ -175,9 +176,7 @@ def _get_nmasks(self):
         try:
             self.n_masks = h5py.File(self.extraction_file, "r")["n_masks"][()].item()
         except Exception as e:
-            raise ValueError(
-                f"Could not extract number of masks from HDF5 file. Error: {e}"
-            ) from e
+            raise ValueError(f"Could not extract number of masks from HDF5 file. Error: {e}") from e
 
     def _general_setup(self, extraction_dir: str, return_results: bool = False):
         """Helper function to execute all setup functions that are common to all featurization steps."""
@@ -892,7 +891,7 @@ class based on the previous single-cell extraction. Therefore, only the second a
         self.log("Started MLClusterClassifier classification.")
 
         # perform setup
-        self._setup(extraction_dir = extraction_dir, return_results=return_results)
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
@@ -975,8 +974,8 @@ def _load_models(self):
            memory_usage = self._get_gpu_memory_usage()
            self.log(f"GPU memory usage after loading models: {memory_usage}")
 
-    def _setup(self, extraction_dir: str):
-        self._general_set(extraction_dir=extraction_dir)
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._setup_transforms()
 
@@ -988,7 +987,7 @@ def _setup(self, extraction_dir: str):
 
         self._load_models()
 
-    def process(self, extraction_dir:str, size:int = 0, return_results:bool = False):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Function called to perform classification on the provided HDF5 dataset.
 
@@ -1335,12 +1334,12 @@ def __init__(self, *args, **kwargs):
            self.channel_selection = None  # ensure that all images are passed to the function
 
-    def _setup(self, extraction_dir:str, return_results:bool):
+    def _setup(self, extraction_dir: str, return_results: bool):
         self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_transforms()
         self._get_channel_specs()
 
-    def process(self, extraction_dir: str, size: int =0, return_results: bool = False):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Perform featurization on the provided HDF5 dataset.
 
@@ -1453,7 +1452,7 @@ def _setup_channel_selection(self):
            self.channel_selection = [0, self.channel_selection]
         return
 
-    def _setup(self, extraction_dir:str, return_results:bool):
+    def _setup(self, extraction_dir: str, return_results: bool):
         self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_channel_selection()
         self._setup_transforms()
diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py
index 138344d..90f2bc3 100644
--- a/src/scportrait/pipeline/project.py
+++ b/src/scportrait/pipeline/project.py
@@ -307,7 +307,7 @@ def _setup_featurization_f(self, featurization_f):
            self.featurization_directory,
            project_location=self.project_location,
            debug=self.debug,
-           overwrite=False, #this needs to be set to false as the featurization step should not remove previously created features
+           overwrite=False,  # this needs to be set to false as the featurization step should not remove previously created features
            project=self,
            filehandler=self.filehandler,
            from_project=True,
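
Note (not part of the patch): a minimal, self-contained sketch of the per-step configuration lookup that ProcessingStep.__init__ performs, as touched by the _base.py hunk above. The class name ExampleStep and the YAML layout are hypothetical, and read_config is simplified here (the real helper wraps yaml.safe_load in a try/except); the point is that when a config is passed as a path, it is loaded and only the section keyed by the step's class name is kept, falling back to the full dictionary otherwise.

import yaml


def read_config(config_path: str) -> dict:
    # simplified stand-in for scportrait.pipeline._utils.helper.read_config
    with open(config_path) as stream:
        return yaml.safe_load(stream)


class ExampleStep:
    # hypothetical step class, used only to illustrate the class-keyed lookup
    def __init__(self, config):
        if isinstance(config, str):
            config = read_config(config)
            if self.__class__.__name__ in config.keys():
                # keep only this step's section, e.g. config["ExampleStep"]
                self.config = config[self.__class__.__name__]
            else:
                self.config = config
        else:
            # config already provided as a dict; use it directly (simplified)
            self.config = config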