From cde2998ff46cdfa3100486c8065d5dbf0d086a2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 27 Jan 2025 01:51:53 +0100 Subject: [PATCH] read masks from hdf5 if not already provided --- src/scportrait/pipeline/featurization.py | 129 +++++++++++++++-------- 1 file changed, 85 insertions(+), 44 deletions(-) diff --git a/src/scportrait/pipeline/featurization.py b/src/scportrait/pipeline/featurization.py index 39d45dd9..32089592 100644 --- a/src/scportrait/pipeline/featurization.py +++ b/src/scportrait/pipeline/featurization.py @@ -6,6 +6,7 @@ from functools import partial as func_partial import numpy as np +import h5py import pandas as pd import pytorch_lightning as pl import torch @@ -17,7 +18,6 @@ from scportrait.tools.ml.datasets import HDF5SingleCellDataset from scportrait.tools.ml.plmodels import MultilabelSupervisedModel - class _FeaturizationBase(ProcessingStep): PRETRAINED_MODEL_NAMES = [ "autophagy_classifier", @@ -170,10 +170,22 @@ def _setup_inference_device(self): self.inference_device = self._detect_automatic_inference_device() self.log(f"Automatically configured inferece device to {self.inference_device}") - def _general_setup(self): + def _get_nmasks(self): + if "n_masks" not in self.__dict__.keys(): + try: + self.n_masks = h5py.File(self.extraction_file, "r")["n_masks"][()].item() + except Exception as e: + raise ValueError( + f"Could not extract number of masks from HDF5 file. Error: {e}" + ) from e + + def _general_setup(self, extraction_dir: str, return_results: bool = False): """Helper function to execute all setup functions that are common to all featurization steps.""" - self._setup_output() + self.extraction_file = extraction_dir + if not return_results: + self._setup_output() + self._get_nmasks() self._setup_log_transform() self._setup_inference_device() @@ -784,8 +796,8 @@ def _setup_transforms(self) -> None: return - def _setup(self): - self._general_setup() + def _setup(self, extraction_dir: str, return_results: bool): + self._general_setup(extraction_dir=extraction_dir, return_results=return_results) self._get_model_specs() self._get_network_dir() @@ -803,7 +815,7 @@ def _setup(self): self._setup_encoders() self._setup_transforms() - def process(self, extraction_dir: str, size: int = 0): + def process(self, extraction_dir: str, size: int = 0, return_results: bool = False): """ Perform classification on the provided HDF5 dataset. @@ -880,7 +892,7 @@ class based on the previous single-cell extraction. Therefore, only the second a self.log("Started MLClusterClassifier classification.") # perform setup - self._setup() + self._setup(extraction_dir = extraction_dir, return_results=return_results) self.dataloader = self.generate_dataloader( extraction_dir, @@ -890,21 +902,28 @@ class based on the previous single-cell extraction. Therefore, only the second a ) # perform inference + all_results = [] for model in self.models: self.log(f"Starting inference for model encoder {model.__name__}") results = self.inference(self.dataloader, model) - output_name = f"inference_{model.__name__}" - path = os.path.join(self.run_path, f"{output_name}.csv") + if not return_results: + output_name = f"inference_{model.__name__}" + path = os.path.join(self.run_path, f"{output_name}.csv") - self._write_results_csv(results, path) - self._write_results_sdata(results, label=f"{self.label}_{model.__name__}") - - self.log(f"Results saved to file: {path}") + self._write_results_csv(results, path) + self._write_results_sdata(results, label=f"{self.label}_{model.__name__}") + else: + all_results.append(results) - # perform post processing cleanup - if not self.deep_debug: - self._post_processing_cleanup() + if return_results: + self._clear_cache() + return all_results + else: + self.log(f"Results saved to file: {path}") + # perform post processing cleanup + if not self.deep_debug: + self._post_processing_cleanup() class EnsembleClassifier(_FeaturizationBase): @@ -956,8 +975,8 @@ def _load_models(self): memory_usage = self._get_gpu_memory_usage() self.log(f"GPU memory usage after loading models: {memory_usage}") - def _setup(self): - self._general_setup() + def _setup(self, extraction_dir: str): + self._general_set(extraction_dir=extraction_dir) self._get_model_specs() self._setup_transforms() @@ -969,7 +988,7 @@ def _setup(self): self._load_models() - def process(self, extraction_dir, size=0): + def process(self, extraction_dir:str, size:int = 0, return_results:bool = False): """ Function called to perform classification on the provided HDF5 dataset. @@ -1024,7 +1043,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee self.log("Starting Ensemble Classification") - self._setup() + self._setup(extraction_dir=extraction_dir, return_results=return_results) self.dataloader = self.generate_dataloader( extraction_dir, @@ -1034,19 +1053,28 @@ class based on the previous single-cell extraction. Therefore, no parameters nee ) # perform inference + all_results = {} for model_name, model in zip(self.model_names, self.model, strict=False): self.log(f"Starting inference for model {model_name}") results = self.inference(self.dataloader, model) output_name = f"ensemble_inference_{model_name}" - path = os.path.join(self.run_path, f"{output_name}.csv") - self._write_results_csv(results, path) - self._write_results_sdata(results, label=model_name) + if not return_results: + path = os.path.join(self.run_path, f"{output_name}.csv") - # perform post processing cleanup - if not self.deep_debug: - self._post_processing_cleanup() + self._write_results_csv(results, path) + self._write_results_sdata(results, label=model_name) + else: + all_results[model_name] = results + + if return_results: + self._clear_cache() + return all_results + else: + # perform post processing cleanup + if not self.deep_debug: + self._post_processing_cleanup() ####### CellFeaturization based on Classic Featurecalculation ####### @@ -1083,10 +1111,19 @@ def _setup_transforms(self): return def _get_channel_specs(self): - if "channel_names" in self.project.__dict__.keys(): - self.channel_names = self.project.channel_names + if self.project is None: + try: + with h5py.File(self.extraction_file, "r") as f: + self.channel_names = list(f["channel_information"][:].astype(str)) + except Exception as e: + raise ValueError( + f"Could not extract channel names from HDF5 file. Please provide channel names manually. Error: {e}" + ) from e else: - self.channel_names = self.project.input_image.c.values + if "channel_names" in self.project.__dict__.keys(): + self.channel_names = self.project.channel_names + else: + self.channel_names = self.project.input_image.c.values def _generate_column_names( self, @@ -1298,12 +1335,12 @@ def __init__(self, *args, **kwargs): self.channel_selection = None # ensure that all images are passed to the function - def _setup(self): - self._general_setup() + def _setup(self, extraction_dir:str, return_results:bool): + self._general_setup(extraction_dir=extraction_dir, return_results=return_results) self._setup_transforms() self._get_channel_specs() - def process(self, extraction_dir, size=0): + def process(self, extraction_dir: str, size: int =0, return_results: bool = False): """ Perform featurization on the provided HDF5 dataset. @@ -1358,7 +1395,7 @@ def process(self, extraction_dir, size=0): self.log("Started CellFeaturization of all available channels.") # perform setup - self._setup() + self._setup(extraction_dir=extraction_dir, return_results=return_results) self.dataloader = self.generate_dataloader( extraction_dir, @@ -1388,15 +1425,19 @@ def process(self, extraction_dir, size=0): column_names=self.column_names, ) - output_name = "calculated_image_features" - path = os.path.join(self.run_path, f"{output_name}.csv") + if return_results: + self._clear_cache() + return results + else: + output_name = "calculated_image_features" + path = os.path.join(self.run_path, f"{output_name}.csv") - self._write_results_csv(results, path) - self._write_results_sdata(results) + self._write_results_csv(results, path) + self._write_results_sdata(results) - # perform post processing cleanup - if not self.deep_debug: - self._post_processing_cleanup() + # perform post processing cleanup + if not self.deep_debug: + self._post_processing_cleanup() class CellFeaturizer_single_channel(_cellFeaturizerBase): @@ -1412,17 +1453,17 @@ def _setup_channel_selection(self): self.channel_selection = [0, self.channel_selection] return - def _setup(self): - self._general_setup() + def _setup(self, extraction_dir:str, return_results:bool): + self._general_setup(extraction_dir=extraction_dir, return_results=return_results) self._setup_channel_selection() self._setup_transforms() self._get_channel_specs() - def process(self, extraction_dir, size=0): + def process(self, extraction_dir, size=0, return_results: bool = False): self.log(f"Started CellFeaturization of selected channel {self.channel_selection}.") # perform setup - self._setup() + self._setup(extraction_dir=extraction_dir, return_results=return_results) self.dataloader = self.generate_dataloader( extraction_dir,