Skip to content

Commit

Permalink
read masks from hdf5 if not already provided
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Jan 27, 2025
1 parent 07098c4 commit cde2998
Showing 1 changed file with 85 additions and 44 deletions.
129 changes: 85 additions & 44 deletions src/scportrait/pipeline/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial as func_partial

import numpy as np
import h5py
import pandas as pd
import pytorch_lightning as pl
import torch
Expand All @@ -17,7 +18,6 @@
from scportrait.tools.ml.datasets import HDF5SingleCellDataset
from scportrait.tools.ml.plmodels import MultilabelSupervisedModel


class _FeaturizationBase(ProcessingStep):
PRETRAINED_MODEL_NAMES = [
"autophagy_classifier",
Expand Down Expand Up @@ -170,10 +170,22 @@ def _setup_inference_device(self):
self.inference_device = self._detect_automatic_inference_device()
self.log(f"Automatically configured inferece device to {self.inference_device}")

def _general_setup(self):
def _get_nmasks(self):
    """Lazily read the number of segmentation masks from the extraction HDF5 file.

    Sets ``self.n_masks`` from the ``n_masks`` dataset of
    ``self.extraction_file`` unless the attribute has already been provided.

    Raises:
        ValueError: If the HDF5 file cannot be opened or has no ``n_masks``
            dataset.
    """
    if "n_masks" not in self.__dict__:
        try:
            # Context manager ensures the file handle is closed; the original
            # opened h5py.File without ever closing it (resource leak).
            with h5py.File(self.extraction_file, "r") as f:
                self.n_masks = f["n_masks"][()].item()
        except Exception as e:
            raise ValueError(f"Could not extract number of masks from HDF5 file. Error: {e}") from e

def _general_setup(self, extraction_dir: str, return_results: bool = False):
    """Helper function to execute all setup functions that are common to all featurization steps.

    Args:
        extraction_dir: Path to the single-cell extraction HDF5 file; stored on
            ``self.extraction_file`` so downstream readers (e.g. ``_get_nmasks``)
            can open it.
        return_results: When True, results will be returned in memory instead of
            written to disk, so no output location is created.
    """
    self.extraction_file = extraction_dir
    if not return_results:
        # Output directories are only needed when results are persisted.
        self._setup_output()
    self._get_nmasks()
    self._setup_log_transform()
    self._setup_inference_device()

Expand Down Expand Up @@ -784,8 +796,8 @@ def _setup_transforms(self) -> None:

return

def _setup(self):
self._general_setup()
def _setup(self, extraction_dir: str, return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._get_model_specs()
self._get_network_dir()

Expand All @@ -803,7 +815,7 @@ def _setup(self):
self._setup_encoders()
self._setup_transforms()

def process(self, extraction_dir: str, size: int = 0):
def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
"""
Perform classification on the provided HDF5 dataset.
Expand Down Expand Up @@ -880,7 +892,7 @@ class based on the previous single-cell extraction. Therefore, only the second a
self.log("Started MLClusterClassifier classification.")

# perform setup
self._setup()
self._setup(extraction_dir = extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
Expand All @@ -890,21 +902,28 @@ class based on the previous single-cell extraction. Therefore, only the second a
)

# perform inference
all_results = []
for model in self.models:
self.log(f"Starting inference for model encoder {model.__name__}")
results = self.inference(self.dataloader, model)

output_name = f"inference_{model.__name__}"
path = os.path.join(self.run_path, f"{output_name}.csv")
if not return_results:
output_name = f"inference_{model.__name__}"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")

self.log(f"Results saved to file: {path}")
self._write_results_csv(results, path)
self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
else:
all_results.append(results)

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
if return_results:
self._clear_cache()
return all_results
else:
self.log(f"Results saved to file: {path}")
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


class EnsembleClassifier(_FeaturizationBase):
Expand Down Expand Up @@ -956,8 +975,8 @@ def _load_models(self):
memory_usage = self._get_gpu_memory_usage()
self.log(f"GPU memory usage after loading models: {memory_usage}")

def _setup(self, extraction_dir: str, return_results: bool = False):
    """Prepare the ensemble classifier for inference.

    Args:
        extraction_dir: Path to the extraction HDF5 file, forwarded to the
            common featurization setup.
        return_results: Forwarded to ``_general_setup``; when True, no output
            location is created because results are returned in memory.
    """
    # BUG FIX: the original called the non-existent ``self._general_set`` and
    # did not accept ``return_results`` even though ``process`` passes it,
    # which would raise AttributeError/TypeError at runtime.
    self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
    self._get_model_specs()
    self._setup_transforms()

Expand All @@ -969,7 +988,7 @@ def _setup(self):

self._load_models()

def process(self, extraction_dir, size=0):
def process(self, extraction_dir:str, size:int = 0, return_results:bool = False):
"""
Function called to perform classification on the provided HDF5 dataset.
Expand Down Expand Up @@ -1024,7 +1043,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee

self.log("Starting Ensemble Classification")

self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
Expand All @@ -1034,19 +1053,28 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
)

# perform inference
all_results = {}
for model_name, model in zip(self.model_names, self.model, strict=False):
self.log(f"Starting inference for model {model_name}")
results = self.inference(self.dataloader, model)

output_name = f"ensemble_inference_{model_name}"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results, label=model_name)
if not return_results:
path = os.path.join(self.run_path, f"{output_name}.csv")

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
self._write_results_csv(results, path)
self._write_results_sdata(results, label=model_name)
else:
all_results[model_name] = results

if return_results:
self._clear_cache()
return all_results
else:
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


####### CellFeaturization based on Classic Featurecalculation #######
Expand Down Expand Up @@ -1083,10 +1111,19 @@ def _setup_transforms(self):
return

def _get_channel_specs(self):
    """Determine the channel names of the input data.

    If no project is attached (``self.project is None``) the names are read
    from the ``channel_information`` dataset of the extraction HDF5 file;
    otherwise they are taken from the project: an explicitly set
    ``channel_names`` attribute if present, else the channel coordinate of
    the project's input image.

    Raises:
        ValueError: If channel names cannot be read from the HDF5 file.
    """
    if self.project is None:
        try:
            with h5py.File(self.extraction_file, "r") as f:
                self.channel_names = list(f["channel_information"][:].astype(str))
        except Exception as e:
            raise ValueError(
                f"Could not extract channel names from HDF5 file. Please provide channel names manually. Error: {e}"
            ) from e
    else:
        if "channel_names" in self.project.__dict__:
            self.channel_names = self.project.channel_names
        else:
            self.channel_names = self.project.input_image.c.values

def _generate_column_names(
self,
Expand Down Expand Up @@ -1298,12 +1335,12 @@ def __init__(self, *args, **kwargs):

self.channel_selection = None # ensure that all images are passed to the function

def _setup(self, extraction_dir: str, return_results: bool = False):
    """Run the setup steps required before featurization of all channels.

    Args:
        extraction_dir: Path to the extraction HDF5 file.
        return_results: When True, skip creation of on-disk output locations
            (results will be returned in memory). Defaults to False for
            backward compatibility and consistency with ``_general_setup``.
    """
    self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
    self._setup_transforms()
    self._get_channel_specs()

def process(self, extraction_dir, size=0):
def process(self, extraction_dir: str, size: int =0, return_results: bool = False):
"""
Perform featurization on the provided HDF5 dataset.
Expand Down Expand Up @@ -1358,7 +1395,7 @@ def process(self, extraction_dir, size=0):
self.log("Started CellFeaturization of all available channels.")

# perform setup
self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
Expand Down Expand Up @@ -1388,15 +1425,19 @@ def process(self, extraction_dir, size=0):
column_names=self.column_names,
)

output_name = "calculated_image_features"
path = os.path.join(self.run_path, f"{output_name}.csv")
if return_results:
self._clear_cache()
return results
else:
output_name = "calculated_image_features"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results)
self._write_results_csv(results, path)
self._write_results_sdata(results)

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


class CellFeaturizer_single_channel(_cellFeaturizerBase):
Expand All @@ -1412,17 +1453,17 @@ def _setup_channel_selection(self):
self.channel_selection = [0, self.channel_selection]
return

def _setup(self, extraction_dir: str, return_results: bool = False):
    """Run the setup steps required before featurization of the selected channel.

    Args:
        extraction_dir: Path to the extraction HDF5 file.
        return_results: When True, skip creation of on-disk output locations
            (results will be returned in memory). Defaults to False for
            backward compatibility and consistency with ``_general_setup``.
    """
    self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
    self._setup_channel_selection()
    self._setup_transforms()
    self._get_channel_specs()

def process(self, extraction_dir, size=0):
def process(self, extraction_dir, size=0, return_results: bool = False):
self.log(f"Started CellFeaturization of selected channel {self.channel_selection}.")

# perform setup
self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
Expand Down

0 comments on commit cde2998

Please sign in to comment.