From cde2998ff46cdfa3100486c8065d5dbf0d086a2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophia=20M=C3=A4dler?=
 <15019107+sophiamaedler@users.noreply.github.com>
Date: Mon, 27 Jan 2025 01:51:53 +0100
Subject: [PATCH] read masks from hdf5 if not already provided

---
 src/scportrait/pipeline/featurization.py | 129 +++++++++++++++--------
 1 file changed, 85 insertions(+), 44 deletions(-)

diff --git a/src/scportrait/pipeline/featurization.py b/src/scportrait/pipeline/featurization.py
index 39d45dd9..32089592 100644
--- a/src/scportrait/pipeline/featurization.py
+++ b/src/scportrait/pipeline/featurization.py
@@ -6,6 +6,7 @@
 from functools import partial as func_partial
 
 import numpy as np
+import h5py
 import pandas as pd
 import pytorch_lightning as pl
 import torch
@@ -17,7 +18,6 @@
 from scportrait.tools.ml.datasets import HDF5SingleCellDataset
 from scportrait.tools.ml.plmodels import MultilabelSupervisedModel
 
-
 class _FeaturizationBase(ProcessingStep):
     PRETRAINED_MODEL_NAMES = [
         "autophagy_classifier",
@@ -170,10 +170,22 @@ def _setup_inference_device(self):
             self.inference_device = self._detect_automatic_inference_device()
             self.log(f"Automatically configured inferece device to {self.inference_device}")
 
-    def _general_setup(self):
+    def _get_nmasks(self):
+        if "n_masks" not in self.__dict__.keys():
+            try:
+                with h5py.File(self.extraction_file, "r") as f:
+                    self.n_masks = f["n_masks"][()].item()
+            except Exception as e:
+                msg = f"Could not extract number of masks from HDF5 file. Error: {e}"
+                raise ValueError(msg) from e
+
+    def _general_setup(self, extraction_dir: str, return_results: bool = False):
         """Helper function to execute all setup functions that are common to all featurization steps."""
 
-        self._setup_output()
+        self.extraction_file = extraction_dir
+        if not return_results:
+            self._setup_output()
+        self._get_nmasks()
         self._setup_log_transform()
         self._setup_inference_device()
 
@@ -784,8 +796,8 @@ def _setup_transforms(self) -> None:
 
         return
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._get_network_dir()
 
@@ -803,7 +815,7 @@ def _setup(self):
         self._setup_encoders()
         self._setup_transforms()
 
-    def process(self, extraction_dir: str, size: int = 0):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Perform classification on the provided HDF5 dataset.
 
@@ -880,7 +892,7 @@ class based on the previous single-cell extraction. Therefore, only the second a
         self.log("Started MLClusterClassifier classification.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
@@ -890,21 +902,28 @@ class based on the previous single-cell extraction. Therefore, only the second a
         )
 
         # perform inference
+        all_results = []
         for model in self.models:
             self.log(f"Starting inference for model encoder {model.__name__}")
             results = self.inference(self.dataloader, model)
 
-            output_name = f"inference_{model.__name__}"
-            path = os.path.join(self.run_path, f"{output_name}.csv")
+            if not return_results:
+                output_name = f"inference_{model.__name__}"
+                path = os.path.join(self.run_path, f"{output_name}.csv")
 
-            self._write_results_csv(results, path)
-            self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
-
-        self.log(f"Results saved to file: {path}")
+                self._write_results_csv(results, path)
+                self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
+            else:
+                all_results.append(results)
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+        if return_results:
+            self._clear_cache()
+            return all_results
+        else:
+            self.log(f"Results saved to file: {path}")
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 class EnsembleClassifier(_FeaturizationBase):
@@ -956,8 +975,8 @@ def _load_models(self):
         memory_usage = self._get_gpu_memory_usage()
         self.log(f"GPU memory usage after loading models: {memory_usage}")
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool = False):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._setup_transforms()
 
@@ -969,7 +988,7 @@ def _setup(self):
 
         self._load_models()
 
-    def process(self, extraction_dir, size=0):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Function called to perform classification on the provided HDF5 dataset.
 
@@ -1024,7 +1043,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
 
         self.log("Starting Ensemble Classification")
 
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
@@ -1034,19 +1053,28 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
         )
 
         # perform inference
+        all_results = {}
         for model_name, model in zip(self.model_names, self.model, strict=False):
             self.log(f"Starting inference for model {model_name}")
             results = self.inference(self.dataloader, model)
 
             output_name = f"ensemble_inference_{model_name}"
-            path = os.path.join(self.run_path, f"{output_name}.csv")
 
-            self._write_results_csv(results, path)
-            self._write_results_sdata(results, label=model_name)
+            if not return_results:
+                path = os.path.join(self.run_path, f"{output_name}.csv")
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+                self._write_results_csv(results, path)
+                self._write_results_sdata(results, label=model_name)
+            else:
+                all_results[model_name] = results
+
+        if return_results:
+            self._clear_cache()
+            return all_results
+        else:
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 ####### CellFeaturization based on Classic Featurecalculation #######
@@ -1083,10 +1111,19 @@ def _setup_transforms(self):
         return
 
     def _get_channel_specs(self):
-        if "channel_names" in self.project.__dict__.keys():
-            self.channel_names = self.project.channel_names
+        if self.project is None:
+            try:
+                with h5py.File(self.extraction_file, "r") as f:
+                    self.channel_names = list(f["channel_information"][:].astype(str))
+            except Exception as e:
+                raise ValueError(
+                    f"Could not extract channel names from HDF5 file. Please provide channel names manually. Error: {e}"
+                ) from e
         else:
-            self.channel_names = self.project.input_image.c.values
+            if "channel_names" in self.project.__dict__.keys():
+                self.channel_names = self.project.channel_names
+            else:
+                self.channel_names = self.project.input_image.c.values
 
     def _generate_column_names(
         self,
@@ -1298,12 +1335,12 @@ def __init__(self, *args, **kwargs):
 
         self.channel_selection = None  # ensure that all images are passed to the function
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_transforms()
         self._get_channel_specs()
 
-    def process(self, extraction_dir, size=0):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Perform featurization on the provided HDF5 dataset.
 
@@ -1358,7 +1395,7 @@ def process(self, extraction_dir, size=0):
         self.log("Started CellFeaturization of all available channels.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
@@ -1388,15 +1425,19 @@ def process(self, extraction_dir, size=0):
             column_names=self.column_names,
         )
 
-        output_name = "calculated_image_features"
-        path = os.path.join(self.run_path, f"{output_name}.csv")
+        if return_results:
+            self._clear_cache()
+            return results
+        else:
+            output_name = "calculated_image_features"
+            path = os.path.join(self.run_path, f"{output_name}.csv")
 
-        self._write_results_csv(results, path)
-        self._write_results_sdata(results)
+            self._write_results_csv(results, path)
+            self._write_results_sdata(results)
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 class CellFeaturizer_single_channel(_cellFeaturizerBase):
@@ -1412,17 +1453,17 @@ def _setup_channel_selection(self):
             self.channel_selection = [0, self.channel_selection]
         return
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_channel_selection()
         self._setup_transforms()
         self._get_channel_specs()
 
-    def process(self, extraction_dir, size=0):
+    def process(self, extraction_dir, size=0, return_results: bool = False):
         self.log(f"Started CellFeaturization of selected channel {self.channel_selection}.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,