Merge pull request #145 from MannLabs/development
merge development branch
sophiamaedler authored Jan 31, 2025

Verified: this commit was signed with the committer's verified signature.
2 parents b46b1e9 + 7391e85 commit 758e7c5
Showing 15 changed files with 511 additions and 153 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
@@ -29,10 +29,10 @@ torch
pytorch-lightning
torchvision

spatialdata
spatialdata>=0.2.0
napari-spatialdata
pyqt5
lxml_html_clean
ashlar>=1.19.0
networkx
py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
py-lmd>=1.3.1
4 changes: 2 additions & 2 deletions requirements_dev.txt
@@ -29,13 +29,13 @@ torch
pytorch-lightning
torchvision

spatialdata
spatialdata>=0.2.0
napari-spatialdata
pyqt5
lxml_html_clean
ashlar>=1.19.0
networkx
py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
py-lmd>=1.3.1

#packages for building the documentation
sphinx
11 changes: 11 additions & 0 deletions src/scportrait/__init__.py
@@ -1,7 +1,18 @@
"""Top-level package for scPortrait"""

# silence warnings
import warnings

from scportrait import io
from scportrait import pipeline as pipeline
from scportrait import plotting as pl
from scportrait import processing as pp
from scportrait import tools as tl

# silence warning from spatialdata resulting from an older dask version, see #139
warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'")

# silence warning from cellpose resulting from a missing parameter set in the model call, see #141
warnings.filterwarnings(
"ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
)
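For reference, `warnings.filterwarnings` treats the `message` argument as a regular expression matched against the start of the warning text, so only the targeted warnings are silenced. A small standalone illustration of the pattern used above (unrelated to scPortrait internals):

import warnings

warnings.filterwarnings(
    "ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
)

# silenced: message and category both match the filter
warnings.warn("You are using `torch.load` with `weights_only=False`. Demo only.", FutureWarning)
# still shown: same category, different message
warnings.warn("Some other future change is coming.", FutureWarning)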
59 changes: 52 additions & 7 deletions src/scportrait/pipeline/_base.py
@@ -9,6 +9,8 @@
import numpy as np
import torch

from scportrait.pipeline._utils.helper import read_config


class Logable:
"""Create log entries.
@@ -92,6 +94,27 @@ def _clean_log_file(self):
if os.path.exists(log_file_path):
os.remove(log_file_path)

# def _clear_cache(self, vars_to_delete=None):
# """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations.

# Args:
# vars_to_delete (list): List of variable names (as strings) to delete.
# """

# # delete all specified variables
# if vars_to_delete is not None:
# for var_name in vars_to_delete:
# if var_name in globals():
# del globals()[var_name]

# if torch.cuda.is_available():
# torch.cuda.empty_cache()

# if torch.backends.mps.is_available():
# torch.mps.empty_cache()

# gc.collect()

def _clear_cache(self, vars_to_delete=None):
"""Helper function to help clear memory usage. Mainly relevant for GPU based segmentations."""

@@ -137,7 +160,7 @@ class ProcessingStep(Logable):
DEFAULT_SEGMENTATION_DIR_NAME = "segmentation"
DEFAULT_TILES_FOLDER = "tiles"

DEFAULT_EXTRACTIN_DIR_NAME = "extraction"
DEFAULT_EXTRACTION_DIR_NAME = "extraction"
DEFAULT_DATA_DIR = "data"

DEFAULT_IMAGE_DTYPE = np.uint16
@@ -155,19 +178,41 @@ class ProcessingStep(Logable):
DEFAULT_SELECTION_DIR_NAME = "selection"

def __init__(
self, config, directory, project_location, debug=False, overwrite=False, project=None, filehandler=None
self,
config,
directory=None,
project_location=None,
debug=False,
overwrite=False,
project=None,
filehandler=None,
from_project: bool = False,
):
super().__init__(directory=directory)

self.debug = debug
self.overwrite = overwrite
self.project_location = project_location
self.config = config
if from_project:
self.project_run = True
self.project_location = project_location
self.project = project
self.filehandler = filehandler
else:
self.project_run = False
self.project_location = None
self.project = None
self.filehandler = None

if isinstance(config, str):
config = read_config(config)
if self.__class__.__name__ in config.keys():
self.config = config[self.__class__.__name__]
else:
self.config = config
else:
self.config = config
self.overwrite = overwrite

self.project = project
self.filehandler = filehandler

self.get_context()

self.deep_debug = False
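In practice, a processing step can now be constructed without a Project by passing a config file path; the step then picks out the config subsection named after its own class. A minimal standalone re-implementation of that selection logic (a sketch, not the scPortrait API itself; the step name in the comment is illustrative):

import yaml

def load_step_config(config_path: str, step_class_name: str) -> dict:
    # load the full YAML config and return the subsection named after the processing
    # step class, falling back to the whole file (mirrors the branch added above)
    with open(config_path) as stream:
        config = yaml.safe_load(stream)
    if step_class_name in config:
        return config[step_class_name]
    return config

# e.g. load_step_config("config.yml", "MyExtractionStep") -> {"compression": True, ...}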
13 changes: 12 additions & 1 deletion src/scportrait/pipeline/_utils/helper.py
@@ -1,9 +1,20 @@
from typing import TypeVar

import yaml

T = TypeVar("T")


def flatten(nested_list: list[list[T]]) -> list[T]:
def read_config(config_path: str) -> dict:
with open(config_path) as stream:
try:
config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
return config


def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
"""Flatten a list of lists into a single list.
Args:
8 changes: 5 additions & 3 deletions src/scportrait/pipeline/_utils/sdata_io.py
@@ -71,10 +71,12 @@ def _read_sdata(self) -> SpatialData:
_sdata = SpatialData()
_sdata.write(self.sdata_path, overwrite=True)

allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
for key in _sdata.labels:
segmentation_object = _sdata.labels[key]
if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
if key in allowed_labels:
segmentation_object = _sdata.labels[key]
if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)

return _sdata

104 changes: 76 additions & 28 deletions src/scportrait/pipeline/extraction.py
@@ -57,7 +57,13 @@ def __init__(self, *args, **kwargs):
self.overwrite_run_path = self.overwrite

def _get_compression_type(self):
self.compression_type = "lzf" if self.compression else None
if (self.compression is True) or (self.compression == "lzf"):
self.compression_type = "lzf"
elif self.compression == "gzip":
self.compression_type = "gzip"
else:
self.compression_type = None
self.log(f"Compression algorithm: {self.compression_type}")
return self.compression_type

def _check_config(self):
@@ -261,24 +267,55 @@ def _get_segmentation_info(self):
f"Found no segmentation masks with key {self.segmentation_key}. Cannot proceed with extraction."
)

# get relevant segmentation masks to perform extraction on
nucleus_key = f"{self.segmentation_key}_nucleus"
# initialize default values to track what should be extracted
self.nucleus_key = None
self.cytosol_key = None
self.extract_nucleus_mask = False
self.extract_cytosol_mask = False

if nucleus_key in relevant_masks:
self.extract_nucleus_mask = True
self.nucleus_key = nucleus_key
else:
self.extract_nucleus_mask = False
self.nucleus_key = None
if "segmentation_mask" in self.config:
allowed_mask_values = ["nucleus", "cytosol"]
allowed_mask_values = [f"{self.segmentation_key}_{x}" for x in allowed_mask_values]

if isinstance(self.config["segmentation_mask"], str):
assert self.config["segmentation_mask"] in allowed_mask_values

cytosol_key = f"{self.segmentation_key}_cytosol"
if "nucleus" in self.config["segmentation_mask"]:
self.nucleus_key = self.config["segmentation_mask"]
self.extract_nucleus_mask = True

elif "cytosol" in self.config["segmentation_mask"]:
self.cytosol_key = self.config["segmentation_mask"]
self.extract_cytosol_mask = True
else:
raise ValueError(
f"Segmentation mask {self.config['segmentation_mask']} is not a valid mask to extract from."
)

elif isinstance(self.config["segmentation_mask"], list):
assert all(x in allowed_mask_values for x in self.config["segmentation_mask"])

for x in self.config["segmentation_mask"]:
if "nucleus" in x:
self.nucleus_key = x
self.extract_nucleus_mask = True
if "cytosol" in x:
self.cytosol_key = x
self.extract_cytosol_mask = True

if cytosol_key in relevant_masks:
self.extract_cytosol_mask = True
self.cytosol_key = cytosol_key
else:
self.extract_cytosol_mask = False
self.cytosol_key = None
# get relevant segmentation masks to perform extraction on
nucleus_key = f"{self.segmentation_key}_nucleus"

if nucleus_key in relevant_masks:
self.extract_nucleus_mask = True
self.nucleus_key = nucleus_key

cytosol_key = f"{self.segmentation_key}_cytosol"

if cytosol_key in relevant_masks:
self.extract_cytosol_mask = True
self.cytosol_key = cytosol_key

self.n_masks = np.sum([self.extract_nucleus_mask, self.extract_cytosol_mask])
self.masks = [x for x in [self.nucleus_key, self.cytosol_key] if x is not None]
@@ -415,7 +452,7 @@ def _save_removed_classes(self, classes):
# define path where classes should be saved
filtered_path = os.path.join(
self.project_location,
self.DEFAULT_SEGMENTATION_DIR_NAME,
self.DEFAULT_EXTRACTION_DIR_NAME,
self.DEFAULT_REMOVED_CLASSES_FILE,
)

@@ -636,7 +673,7 @@ def _transfer_tempmmap_to_hdf5(self):
axs[i].imshow(img, vmin=0, vmax=1)
axs[i].axis("off")
fig.tight_layout()
fig.show()
plt.show(fig)

self.log("Transferring extracted single cells to .hdf5")

@@ -651,7 +688,8 @@ def _transfer_tempmmap_to_hdf5(self):
) # increase to 64 bit otherwise information may become truncated

self.log("single-cell index created.")
self._clear_cache(vars_to_delete=[cell_ids])
del cell_ids
# self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly

_, c, x, y = _tmp_single_cell_data.shape
single_cell_data = hf.create_dataset(
@@ -668,7 +706,8 @@ def _transfer_tempmmap_to_hdf5(self):
single_cell_data[ix] = _tmp_single_cell_data[i]

self.log("single-cell data created")
self._clear_cache(vars_to_delete=[single_cell_data])
del single_cell_data
# self._clear_cache(vars_to_delete=[single_cell_data]) # this is not working as expected so we will just delete the variable directly

# also transfer labelled index to HDF5
index_labelled = _tmp_single_cell_index[keep_index]
@@ -684,18 +723,27 @@ def _transfer_tempmmap_to_hdf5(self):
hf.create_dataset("single_cell_index_labelled", data=index_labelled, chunks=None, dtype=dt)

self.log("single-cell index labelled created.")
self._clear_cache(vars_to_delete=[index_labelled])
del index_labelled
# self._clear_cache(vars_to_delete=[index_labelled]) # this is not working as expected so we will just delete the variable directly

hf.create_dataset(
"channel_information",
data=np.char.encode(self.channel_names.astype(str)),
dtype=h5py.special_dtype(vlen=str),
)

hf.create_dataset(
"n_masks",
data=self.n_masks,
dtype=int,
)

self.log("channel information created.")

# cleanup memory
self._clear_cache(vars_to_delete=[_tmp_single_cell_index, index_labelled])
del _tmp_single_cell_index
# self._clear_cache(vars_to_delete=[_tmp_single_cell_index]) # this is not working as expected so we will just delete the variable directly

os.remove(self._tmp_single_cell_data_path)
os.remove(self._tmp_single_cell_index_path)

@@ -808,7 +856,6 @@ def process(self, partial=False, n_cells=None, seed=42):
# directory where intermediate results should be saved
cache: "/mnt/temp/cache"
"""

total_time_start = timeit.default_timer()

start_setup = timeit.default_timer()
@@ -871,31 +918,33 @@ def process(self, partial=False, n_cells=None, seed=42):

self.log("Running in single threaded mode.")
results = []
for arg in tqdm(args):
for arg in tqdm(args, total=len(args), desc="Processing cell batches"):
x = f(arg)
results.append(x)
else:
# set up function for multi-threaded processing
f = func_partial(self._extract_classes_multi, self.px_centers)
batched_args = self._generate_batched_args(args)
args = self._generate_batched_args(args)

self.log(f"Running in multiprocessing mode with {self.threads} threads.")
with mp.get_context("fork").Pool(
processes=self.threads
) as pool: # both spawn and fork work but fork is faster so forcing fork here
results = list(
tqdm(
pool.imap(f, batched_args),
total=len(batched_args),
pool.imap(f, args),
total=len(args),
desc="Processing cell batches",
)
)
pool.close()
pool.join()
print("multiprocessing done.")

self.save_index_to_remove = flatten(results)

# cleanup memory and remove any no longer required variables
del results, args
# self._clear_cache(vars_to_delete=["results", "args"]) # this is not working as expected at the moment so need to manually delete the variables
stop_extraction = timeit.default_timer()

# calculate duration
@@ -912,7 +961,6 @@ def process(self, partial=False, n_cells=None, seed=42):
self.DEFAULT_LOG_NAME = "processing.log" # change log name back to default

self._post_extraction_cleanup()

total_time_stop = timeit.default_timer()
total_time = total_time_stop - total_time_start
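Taken together, the extraction step now accepts an explicit `segmentation_mask` entry in its config (a single mask key or a list of keys) instead of always probing for `<segmentation_key>_nucleus` / `<segmentation_key>_cytosol`, and `compression` may be set to "gzip" in addition to True/"lzf". A simplified sketch of the mask-resolution logic, with an assumed config dict for illustration:

def resolve_masks(config: dict, relevant_masks: list[str], segmentation_key: str = "seg_all"):
    """Mirror of the decision logic added to _get_segmentation_info (simplified sketch)."""
    nucleus_key = cytosol_key = None
    if "segmentation_mask" in config:
        requested = config["segmentation_mask"]
        requested = [requested] if isinstance(requested, str) else requested
        allowed = {f"{segmentation_key}_nucleus", f"{segmentation_key}_cytosol"}
        assert all(x in allowed for x in requested), "invalid segmentation_mask entry"
        for x in requested:
            if "nucleus" in x:
                nucleus_key = x
            elif "cytosol" in x:
                cytosol_key = x
    else:  # fall back to whichever default masks are present
        if f"{segmentation_key}_nucleus" in relevant_masks:
            nucleus_key = f"{segmentation_key}_nucleus"
        if f"{segmentation_key}_cytosol" in relevant_masks:
            cytosol_key = f"{segmentation_key}_cytosol"
    masks = [m for m in (nucleus_key, cytosol_key) if m is not None]
    return masks, len(masks)

config = {"compression": "gzip", "segmentation_mask": ["seg_all_nucleus", "seg_all_cytosol"]}
print(resolve_masks(config, ["seg_all_nucleus", "seg_all_cytosol"]))  # (['seg_all_nucleus', 'seg_all_cytosol'], 2)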

173 changes: 126 additions & 47 deletions src/scportrait/pipeline/featurization.py
@@ -5,6 +5,7 @@
from contextlib import redirect_stdout
from functools import partial as func_partial

import h5py
import numpy as np
import pandas as pd
import pytorch_lightning as pl
@@ -36,6 +37,7 @@ def __init__(self, *args, **kwargs):
self.model = None
self.transforms = None
self.expected_imagesize = None
self.data_type = None

self._setup_channel_selection()

@@ -59,7 +61,10 @@ def _setup_output(self):
if not os.path.isdir(self.directory):
os.makedirs(self.directory)

self.run_path = os.path.join(self.directory, f"{self.data_type}_{self.label}")
if self.data_type is None:
self.run_path = os.path.join(self.directory, self.label)
else:
self.run_path = os.path.join(self.directory, f"{self.data_type}_{self.label}")

if not os.path.isdir(self.run_path):
os.makedirs(self.run_path)
@@ -104,6 +109,25 @@ def _detect_automatic_inference_device(self):

return inference_device

def _get_nmasks(self):
if "n_masks" not in self.__dict__.keys():
if isinstance(self.extraction_file, str):
with h5py.File(self.extraction_file, "r") as f:
self.n_masks = f["n_masks"][()].item()
if isinstance(self.extraction_file, list):
n_masks = []
for file in self.extraction_file:
with h5py.File(file, "r") as f:
n_masks.append(f["n_masks"][()].item())
assert all(
x == n_masks[0] for x in n_masks
), "number of masks are not consistent over all passed HDF5 files."
self.n_masks = n_masks[0]
try:
self.n_masks = h5py.File(self.extraction_file, "r")["n_masks"][()].item()
except Exception as e:
raise ValueError(f"Could not extract number of masks from HDF5 file. Error: {e}") from e

def _setup_inference_device(self):
"""
Configure the featurization run to use the specified inference device.
@@ -166,10 +190,13 @@ def _setup_inference_device(self):
self.inference_device = self._detect_automatic_inference_device()
self.log(f"Automatically configured inference device to {self.inference_device}")

def _general_setup(self):
def _general_setup(self, extraction_dir: str | list[str], return_results: bool = False):
"""Helper function to execute all setup functions that are common to all featurization steps."""

self._setup_output()
self.extraction_file = extraction_dir
if not return_results:
self._setup_output()
self._get_nmasks()
self._setup_log_transform()
self._setup_inference_device()

@@ -391,7 +418,8 @@ def configure_transforms(self, selected_transforms: list):

def generate_dataloader(
self,
extraction_dir: str,
extraction_dir: str | list[str],
labels: int | list[int] = 0,
selected_transforms: transforms.Compose = transforms.Compose([]),
size: int = 0,
seed: int | None = 42,
@@ -428,11 +456,20 @@ def generate_dataloader(
self.log(f"Expected image size is set to {self.expected_imagesize}. Resizing images to this size.")
t = transforms.Compose([t, transforms.Resize(self.expected_imagesize)])

if isinstance(extraction_dir, list):
assert isinstance(labels, list), "If multiple directories are provided, multiple labels must be provided."
paths = extraction_dir
labels = labels
elif isinstance(extraction_dir, str):
assert isinstance(labels, int), "If only one directory is provided, only one label must be provided."
paths = [extraction_dir]
labels = [labels]

f = io.StringIO()
with redirect_stdout(f):
dataset = dataset_class(
dir_list=[extraction_dir],
dir_labels=[0],
dir_list=paths,
dir_labels=labels,
transform=t,
return_id=True,
select_channel=self.channel_selection,
@@ -780,8 +817,8 @@ def _setup_transforms(self) -> None:

return

def _setup(self):
self._general_setup()
def _setup(self, extraction_dir: str, return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._get_model_specs()
self._get_network_dir()

@@ -799,7 +836,7 @@ def _setup(self):
self._setup_encoders()
self._setup_transforms()

def process(self, extraction_dir: str, size: int = 0):
def process(self, extraction_dir: str, labels: int | list[int] = 0, size: int = 0, return_results: bool = False):
"""
Perform classification on the provided HDF5 dataset.
@@ -876,31 +913,39 @@ class based on the previous single-cell extraction. Therefore, only the second a
self.log("Started MLClusterClassifier classification.")

# perform setup
self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
labels=labels,
selected_transforms=self.transforms,
size=size,
dataset_class=self.DEFAULT_DATA_LOADER,
)

# perform inference
all_results = []
for model in self.models:
self.log(f"Starting inference for model encoder {model.__name__}")
results = self.inference(self.dataloader, model)

output_name = f"inference_{model.__name__}"
path = os.path.join(self.run_path, f"{output_name}.csv")
if not return_results:
output_name = f"inference_{model.__name__}"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")

self.log(f"Results saved to file: {path}")
self._write_results_csv(results, path)
self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
else:
all_results.append(results)

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
if return_results:
self._clear_cache()
return all_results
else:
self.log(f"Results saved to file: {path}")
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


class EnsembleClassifier(_FeaturizationBase):
@@ -952,8 +997,8 @@ def _load_models(self):
memory_usage = self._get_gpu_memory_usage()
self.log(f"GPU memory usage after loading models: {memory_usage}")

def _setup(self):
self._general_setup()
def _setup(self, extraction_dir: str, return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._get_model_specs()
self._setup_transforms()

@@ -965,7 +1010,7 @@ def _setup(self):

self._load_models()

def process(self, extraction_dir, size=0):
def process(self, extraction_dir: str, labels: int | list[int] = 0, size: int = 0, return_results: bool = False):
"""
Function called to perform classification on the provided HDF5 dataset.
@@ -1020,29 +1065,39 @@ class based on the previous single-cell extraction. Therefore, no parameters nee

self.log("Starting Ensemble Classification")

self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
labels=labels,
selected_transforms=self.transforms,
size=size,
dataset_class=self.DEFAULT_DATA_LOADER,
)

# perform inference
all_results = {}
for model_name, model in zip(self.model_names, self.model, strict=False):
self.log(f"Starting inference for model {model_name}")
results = self.inference(self.dataloader, model)

output_name = f"ensemble_inference_{model_name}"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results, label=model_name)
if not return_results:
path = os.path.join(self.run_path, f"{output_name}.csv")

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
self._write_results_csv(results, path)
self._write_results_sdata(results, label=model_name)
else:
all_results[model_name] = results

if return_results:
self._clear_cache()
return all_results
else:
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


####### CellFeaturization based on Classic Featurecalculation #######
@@ -1079,10 +1134,24 @@ def _setup_transforms(self):
return

def _get_channel_specs(self):
if "channel_names" in self.project.__dict__.keys():
self.channel_names = self.project.channel_names
if self.project is None:
if isinstance(self.extraction_file, str):
with h5py.File(self.extraction_file, "r") as f:
self.channel_names = list(f["channel_information"][:].astype(str))
if isinstance(self.extraction_file, list):
channel_names = []
for file in self.extraction_file:
with h5py.File(file, "r") as f:
channel_names.append(list(f["channel_information"][:].astype(str)))
assert all(
x == channel_names[0] for x in channel_names
), "Channel names are not consistent over all passed HDF5 files."
self.channel_names = channel_names[0]
else:
self.channel_names = self.project.input_image.c.values
if "channel_names" in self.project.__dict__.keys():
self.channel_names = self.project.channel_names
else:
self.channel_names = self.project.input_image.c.values

def _generate_column_names(
self,
@@ -1294,12 +1363,14 @@ def __init__(self, *args, **kwargs):

self.channel_selection = None # ensure that all images are passed to the function

def _setup(self):
self._general_setup()
def _setup(self, extraction_dir: str | list[str], return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._setup_transforms()
self._get_channel_specs()

def process(self, extraction_dir, size=0):
def process(
self, extraction_dir: str | list[str], labels: int | list[int] = 0, size: int = 0, return_results: bool = False
):
"""
Perform featurization on the provided HDF5 dataset.
@@ -1354,10 +1425,11 @@ def process(self, extraction_dir, size=0):
self.log("Started CellFeaturization of all available channels.")

# perform setup
self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
labels=labels,
selected_transforms=self.transforms,
size=size,
dataset_class=self.DEFAULT_DATA_LOADER,
@@ -1384,15 +1456,19 @@ def process(self, extraction_dir, size=0):
column_names=self.column_names,
)

output_name = "calculated_image_features"
path = os.path.join(self.run_path, f"{output_name}.csv")
if return_results:
self._clear_cache()
return results
else:
output_name = "calculated_image_features"
path = os.path.join(self.run_path, f"{output_name}.csv")

self._write_results_csv(results, path)
self._write_results_sdata(results)
self._write_results_csv(results, path)
self._write_results_sdata(results)

# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()
# perform post processing cleanup
if not self.deep_debug:
self._post_processing_cleanup()


class CellFeaturizer_single_channel(_cellFeaturizerBase):
@@ -1408,20 +1484,23 @@ def _setup_channel_selection(self):
self.channel_selection = [0, self.channel_selection]
return

def _setup(self):
self._general_setup()
def _setup(self, extraction_dir: str | list[str], return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._setup_channel_selection()
self._setup_transforms()
self._get_channel_specs()

def process(self, extraction_dir, size=0):
def process(
self, extraction_dir: str | list[str], labels: int | list[int] = 0, size=0, return_results: bool = False
):
self.log(f"Started CellFeaturization of selected channel {self.channel_selection}.")

# perform setup
self._setup()
self._setup(extraction_dir=extraction_dir, return_results=return_results)

self.dataloader = self.generate_dataloader(
extraction_dir,
labels=labels,
selected_transforms=self.transforms,
size=size,
dataset_class=self.DEFAULT_DATA_LOADER,
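A hedged usage sketch of the new featurization entry points, based only on the signatures introduced above; the constructor arguments and HDF5 file names are assumptions, not the documented API:

from scportrait.pipeline.featurization import MLClusterClassifier  # module path as in this diff

# standalone use without a Project: a config path is accepted directly (see the _base.py change above)
classifier = MLClusterClassifier(config="config.yml", directory="./featurization_results")

# several extraction files can be featurized together, one integer label per file;
# return_results=True skips the CSV/spatialdata output and returns the inference results instead
results = classifier.process(
    extraction_dir=["run_a/single_cells.h5", "run_b/single_cells.h5"],
    labels=[0, 1],
    return_results=True,
)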
30 changes: 15 additions & 15 deletions src/scportrait/pipeline/project.py
@@ -23,7 +23,6 @@
import numpy as np
import psutil
import xarray
import yaml
from alphabase.io import tempmmap
from napari_spatialdata import Interactive
from ome_zarr.io import parse_url
@@ -33,6 +32,7 @@

from scportrait.io import daskmmap
from scportrait.pipeline._base import Logable
from scportrait.pipeline._utils.helper import read_config
from scportrait.pipeline._utils.sdata_io import sdata_filehandler
from scportrait.pipeline._utils.spatialdata_helper import (
calculate_centroids,
@@ -94,7 +94,7 @@ class Project(Logable):
def __init__(
self,
project_location: str,
config_path: str,
config_path: str = None,
segmentation_f=None,
extraction_f=None,
featurization_f=None,
@@ -185,11 +185,7 @@ def _load_config_from_file(self, file_path):
if not os.path.isfile(file_path):
raise ValueError(f"Your config path {file_path} is invalid.")

with open(file_path) as stream:
try:
self.config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
self.config = read_config(file_path)

def _get_config_file(self, config_path: str | None = None) -> None:
"""Load the config file for the project. If no config file is passed the default config file in the project directory is loaded.
@@ -257,6 +253,7 @@ def _setup_segmentation_f(self, segmentation_f):
overwrite=self.overwrite,
project=None,
filehandler=self.filehandler,
from_project=True,
)

def _setup_extraction_f(self, extraction_f):
@@ -285,6 +282,7 @@ def _setup_extraction_f(self, extraction_f):
overwrite=self.overwrite,
project=self,
filehandler=self.filehandler,
from_project=True,
)

def _setup_featurization_f(self, featurization_f):
@@ -309,9 +307,10 @@ def _setup_featurization_f(self, featurization_f):
self.featurization_directory,
project_location=self.project_location,
debug=self.debug,
overwrite=self.overwrite,
overwrite=False, # this needs to be set to false as the featurization step should not remove previously created features
project=self,
filehandler=self.filehandler,
from_project=True,
)

def _setup_selection(self, selection_f):
@@ -339,6 +338,7 @@ def _setup_selection(self, selection_f):
overwrite=self.overwrite,
project=self,
filehandler=self.filehandler,
from_project=True,
)

def update_featurization_f(self, featurization_f):
@@ -888,9 +888,10 @@ def load_input_from_sdata(
# ensure that the provided nucleus and cytosol segmentations fulfill the scPortrait requirements
# requirements are:
# 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids
assert (
self.sdata[self.nuc_seg_name].attrs["cell_ids"] == self.sdata[self.cyto_seg_name].attrs["cell_ids"]
), "The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids."
if self.nuc_seg_status in self.sdata.keys() and self.cyto_seg_status in self.sdata.keys():
assert (
self.sdata[self.nuc_seg_name].attrs["cell_ids"] == self.sdata[self.cyto_seg_name].attrs["cell_ids"]
), "The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids."

# 2. the nucleus segmentation ids and the cytosol segmentation ids need to match
# THIS NEEDS TO BE IMPLEMENTED HERE
@@ -1066,6 +1067,8 @@ def featurize(
# setup overwrite if specified in call
if overwrite is not None:
self.featurization_f.overwrite_run_path = overwrite
if overwrite is None:
self.featurization_f.overwrite_run_path = True

# update the number of masks that are available in the segmentation object
self.featurization_f.n_masks = sum([self.nuc_seg_status, self.cyto_seg_status])
@@ -1079,7 +1082,6 @@ def select(
self,
cell_sets: list[dict],
calibration_marker: np.ndarray | None = None,
segmentation_name: str = "seg_all_nucleus",
name: str | None = None,
):
"""
@@ -1091,14 +1093,12 @@ def select(

self._check_sdata_status()

if not self.nuc_seg_status or not self.cyto_seg_status:
if not self.nuc_seg_status and not self.cyto_seg_status:
raise ValueError("No nucleus or cytosol segmentation loaded. Please load a segmentation first.")

assert self.sdata is not None, "No sdata object loaded."
assert segmentation_name in self.sdata.labels, f"Segmentation {segmentation_name} not found in sdata object."

self.selection_f(
segmentation_name=segmentation_name,
cell_sets=cell_sets,
calibration_marker=calibration_marker,
name=name,
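With `segmentation_name` removed from `Project.select`, the segmentation to cut is taken from the selection step's own config (`segmentation_channel`), so a call after this change might look as follows; the `project` object and the marker coordinates are assumed placeholders:

import numpy as np

# three calibration markers in (row, column) format, as described in the selection docstring
calibration_marker = np.array([[-10, 1100], [505, 1600], [930, 1300]])

cells_to_select = [{"name": "dataset1", "classes": [1, 2, 3], "well": "A1"}]

project.select(
    cell_sets=cells_to_select,
    calibration_marker=calibration_marker,
)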
8 changes: 5 additions & 3 deletions src/scportrait/pipeline/segmentation/segmentation.py
@@ -85,6 +85,7 @@ def __init__(
overwrite,
project,
filehandler,
from_project: bool = False,
**kwargs,
):
super().__init__(
@@ -95,6 +96,7 @@ def __init__(
overwrite=overwrite,
project=project,
filehandler=filehandler,
from_project=from_project,
)

if self.directory is not None:
@@ -742,7 +744,6 @@ def _resolve_sharding(self, sharding_plan):
local_hf = h5py.File(local_output, "r")
local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)[:]

print(type(local_hdf_labels))
shifted_map, edge_labels = shift_labels(
local_hdf_labels,
class_id_shift,
@@ -902,8 +903,9 @@ def _resolve_sharding(self, sharding_plan):
if not self.deep_debug:
self._cleanup_shards(sharding_plan)

def _initializer_function(self, gpu_id_list):
def _initializer_function(self, gpu_id_list, n_processes):
current_process().gpu_id_list = gpu_id_list
current_process().n_processes = n_processes

def _perform_segmentation(self, shard_list):
# get GPU status
@@ -921,7 +923,7 @@ def _perform_segmentation(self, shard_list):
with mp.get_context(self.context).Pool(
processes=self.n_processes,
initializer=self._initializer_function,
initargs=[self.gpu_id_list],
initargs=[self.gpu_id_list, self.n_processes],
) as pool:
list(
tqdm(
9 changes: 8 additions & 1 deletion src/scportrait/pipeline/segmentation/workflows.py
@@ -653,7 +653,11 @@ def _check_for_mask_matching_filtering(self) -> None:
else:
# add deprecation warning for old config setup
if "filter_status" in self.config.keys():
Warning("filter_status is deprecated, please use match_masks instead Will not perform filtering.")
self.filter_match_masks = True
self.mask_matching_filtering_threshold = 0.95
Warning(
"filter_status is deprecated, please use match_masks instead. Will use default settings for mask matching."
)

# default behaviour that this filtering should be performed, otherwise another additional step is required before extraction
self.filter_match_masks = True
@@ -1349,6 +1353,9 @@ def _check_gpu_status(self):
gpu_id_list = current.gpu_id_list
cpu_id = int(cpu_name[cpu_name.find("-") + 1 :]) - 1

if cpu_id >= len(gpu_id_list):
cpu_id = cpu_id % current.n_processes

# track gpu_id and update GPU status
self.gpu_id = gpu_id_list[cpu_id]
self.status = "multi_GPU"
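The pool initializer now also records the worker count, and the GPU lookup wraps around when a worker's numeric name exceeds the GPU list (which can happen when pools are created repeatedly within one process). A self-contained sketch of the same pattern, outside of scPortrait:

import multiprocessing as mp
from multiprocessing import current_process

def _initializer(gpu_id_list, n_processes):
    # attach shared metadata to every worker process, as in Segmentation._initializer_function
    current_process().gpu_id_list = gpu_id_list
    current_process().n_processes = n_processes

def work(shard):
    proc = current_process()
    cpu_id = int(proc.name.rsplit("-", 1)[1]) - 1  # worker index parsed from the process name
    if cpu_id >= len(proc.gpu_id_list):
        cpu_id = cpu_id % proc.n_processes  # wrap around, mirroring the workflows.py fix
    return shard, proc.gpu_id_list[cpu_id]

if __name__ == "__main__":
    with mp.get_context("spawn").Pool(2, initializer=_initializer, initargs=([0, 1], 2)) as pool:
        print(pool.map(work, range(4)))  # each shard is tagged with the GPU id of its worker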
200 changes: 171 additions & 29 deletions src/scportrait/pipeline/selection.py
@@ -1,10 +1,20 @@
import multiprocessing as mp
import os
import pickle
import timeit
from functools import partial as func_partial

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from alphabase.io import tempmmap
from lmd.lib import SegmentationLoader
from scipy.sparse import coo_array
from tqdm.auto import tqdm

from scportrait.pipeline._base import ProcessingStep
from scportrait.pipeline._utils.helper import flatten


class LMDSelection(ProcessingStep):
@@ -13,20 +23,60 @@ class LMDSelection(ProcessingStep):
This method class relies on the functionality of the pylmd library.
"""

# define all valid path optimization methods used with the "path_optimization" argument in the configuration
VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._check_config()

self.name = None
self.cell_sets = None
self.calibration_marker = None

def _setup_selection(self):
# set orientation transform
self.config["orientation_transform"] = np.array([[0, -1], [1, 0]])
self.deep_debug = False # flag for deep debugging by developers

def _check_config(self):
assert "segmentation_channel" in self.config, "segmentation_channel not defined in config"
self.segmentation_channel_to_select = self.config["segmentation_channel"]

# check for optional config parameters

# this defines how large the box mask around the center of a cell is for the coordinate extraction
# assumption is that all pixels belonging to each mask are within the box otherwise they will be cut off during cutting contour generation

if "cell_width" in self.config:
self.cell_radius = self.config["cell_width"]
else:
self.cell_radius = 100

if "threads" in self.config:
self.threads = self.config["threads"]
assert self.threads > 0, "threads must be greater than 0"
assert isinstance(self.threads, int), "threads must be an integer"
else:
self.threads = 10

if "batch_size_coordinate_extraction" in self.config:
self.batch_size = self.config["batch_size_coordinate_extraction"]
assert self.batch_size > 0, "batch_size_coordinate_extraction must be greater than 0"
assert isinstance(self.batch_size, int), "batch_size_coordinate_extraction must be an integer"
else:
self.batch_size = 100

if "orientation_transform" in self.config:
self.orientation_transform = self.config["orientation_transform"]
else:
self.orientation_transform = np.array([[0, -1], [1, 0]])
self.config["orientation_transform"] = (
self.orientation_transform
) # ensure its also in config so its passed on to the segmentation loader

if "processes_cell_sets" in self.config:
self.processes_cell_sets = self.config["processes_cell_sets"]
assert self.processes_cell_sets > 0, "processes_cell_sets must be greater than 0"
assert isinstance(self.processes_cell_sets, int), "processes_cell_sets must be an integer"
else:
self.processes_cell_sets = 1

def _setup_selection(self):
# configure name of extraction
if self.name is None:
try:
@@ -39,6 +89,111 @@ def _setup_selection(self):
savename = name.replace(" ", "_") + ".xml"
self.savepath = os.path.join(self.directory, savename)

# check that the segmentation label exists
assert (
self.segmentation_channel_to_select in self.project.filehandler.get_sdata()._shared_keys
), f"Segmentation channel {self.segmentation_channel_to_select} not found in sdata."

def __get_coords(
self, cell_ids: list, centers: list[tuple[int, int]], width: int = 60
) -> list[tuple[int, np.ndarray]]:
results = []

_sdata = self.project.filehandler.get_sdata()
for i, _id in enumerate(cell_ids):
values = centers[i]

x_start = np.max([int(values[0]) - width, 0])
y_start = np.max([int(values[1]) - width, 0])

x_end = x_start + width * 2
y_end = y_start + width * 2

_cropped = _sdata[self.segmentation_channel_to_select][
slice(x_start, x_end), slice(y_start, y_end)
].compute()

# optional plotting output for deep debugging
if self.deep_debug:
if self.threads == 1:
plt.figure()
plt.imshow(_cropped)
plt.show()
else:
raise ValueError("Deep debug is not supported with multiple threads.")

sparse = coo_array(_cropped == _id)

if (
0 in sparse.coords[0]
or 0 in sparse.coords[1]
or width * 2 - 1 in sparse.coords[0]
or width * 2 - 1 in sparse.coords[1]
):
Warning(
f"Cell {i} with id {_id} is potentially not fully contained in the bounding mask. Consider increasing the value for the 'cell_width' parameter in your config."
)

x = sparse.coords[0] + x_start
y = sparse.coords[1] + y_start

results.append((_id, np.array(list(zip(x, y, strict=True)))))

return results

def _get_coords_multi(self, width: int, arg: tuple[list[int], np.ndarray]) -> list[tuple[int, np.ndarray]]:
cell_ids, centers = arg
results = self.__get_coords(cell_ids, centers, width)
return results

def _get_coords(
self, cell_ids: list, centers: list[tuple[int, int]], width: int = 60, batch_size: int = 100, threads: int = 10
) -> dict[int, np.ndarray]:
# create batches
n_batches = int(np.ceil(len(cell_ids) / batch_size))
slices = [(i * batch_size, i * batch_size + batch_size) for i in range(n_batches - 1)]
slices.append(((n_batches - 1) * batch_size, len(cell_ids)))

batched_args = [(cell_ids[start:end], centers[start:end]) for start, end in slices]

f = func_partial(self._get_coords_multi, width)

if (
threads == 1
): # if only one thread is used, the function is called directly to avoid the overhead of multiprocessing
results = [f(arg) for arg in batched_args]
else:
with mp.get_context(self.context).Pool(processes=threads) as pool:
results = list(
tqdm(
pool.imap(f, batched_args),
total=len(batched_args),
desc="Processing cell batches",
)
)
pool.close()
pool.join()

results = flatten(results) # type: ignore
return dict(results) # type: ignore

def _get_cell_ids(self, cell_sets: list[dict]) -> list[int]:
cell_ids = []
for cell_set in cell_sets:
if "classes" in cell_set:
cell_ids.extend(cell_set["classes"])
else:
Warning(f"Cell set {cell_set['name']} does not contain any classes.")
return cell_ids

def _get_centers(self, cell_ids: list[int]) -> list[tuple[int, int]]:
_sdata = self.project.filehandler.get_sdata()
centers = _sdata["centers_cells"].compute()
centers = centers.loc[cell_ids, :]
return centers[
["y", "x"]
].values.tolist() # needs to be returned as yx to match the coordinate system as saved in spatialdataobjects

def _post_processing_cleanup(self, vars_to_delete: list | None = None):
if vars_to_delete is not None:
self._clear_cache(vars_to_delete=vars_to_delete)
@@ -51,7 +206,6 @@ def _post_processing_cleanup(self, vars_to_delete: list | None = None):

def process(
self,
segmentation_name: str,
cell_sets: list[dict],
calibration_marker: np.array,
name: str | None = None,
@@ -61,9 +215,9 @@ def process(
Under the hood this method relies on the pylmd library and utilizes its `SegmentationLoader` class.
Args:
segmentation_name (str): Name of the segmentation to be used for shape generation in the sdata object.
cell_sets (list of dict): List of dictionaries containing the sets of cells which should be sorted into a single well. Mandatory keys for each dictionary are: name, classes. Optional keys are: well.
calibration_marker (numpy.array): Array of size ‘(3,2)’ containing the calibration marker coordinates in the ‘(row, column)’ format.
name (str, optional): Name of the output file. If not provided, the name will be generated based on the names of the cell sets or if also not specified set to "selected_cells".
Example:
@@ -77,7 +231,6 @@ def process(
# A numpy Array of shape (3, 2) should be passed.
calibration_marker = np.array([marker_0, marker_1, marker_2])
# Sets of cells can be defined by providing a name and a list of classes in a dictionary.
cells_to_select = [{"name": "dataset1", "classes": [1, 2, 3]}]
@@ -122,7 +275,7 @@ def process(
convolution_smoothing: 25
# fold reduction of datapoints for compression
poly_compression_factor: 30
rdp: 0.7
# Optimization of the cutting path in between shapes
# optimized paths improve the cutting time and the microscopes focus
@@ -160,32 +313,21 @@ def process(

self._setup_selection()

## TO Do
# check if classes and seglookup table already exist as pickle file
# if not create them
# else load them and proceed with selection

# load segmentation from hdf5
self.path_seg_mask = self.filehandler._load_seg_to_memmap(
[segmentation_name], tmp_dir_abs_path=self._tmp_dir_path
start_time = timeit.default_timer()
cell_ids = self._get_cell_ids(cell_sets)
centers = self._get_centers(cell_ids)
coord_index = self._get_coords(
cell_ids=cell_ids, centers=centers, width=self.cell_radius, batch_size=self.batch_size, threads=self.threads
)
self.log(f"Coordinate lookup index calculation took {timeit.default_timer() - start_time} seconds.")

segmentation = tempmmap.mmap_array_from_path(self.path_seg_mask)

# create segmentation loader
sl = SegmentationLoader(
config=self.config,
verbose=self.debug,
processes=self.config["processes_cell_sets"],
)

if len(segmentation.shape) == 3:
segmentation = np.squeeze(segmentation)
else:
raise ValueError(f"Segmentation shape is not correct. Expected 2D array, got {segmentation.shape}")

# get shape collections
shape_collection = sl(segmentation, self.cell_sets, self.calibration_marker)
shape_collection = sl(None, self.cell_sets, self.calibration_marker, coords_lookup=coord_index)

if self.debug:
shape_collection.plot(calibration=True)
@@ -196,4 +338,4 @@ def process(
self.log(f"Saved output at {self.savepath}")

# perform post processing cleanup
self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, segmentation])
self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, coord_index])
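The selection step no longer loads the full segmentation into a memmap; instead it looks up each cell's pixels in a small window around its center and hands the resulting coordinate index to the SegmentationLoader. A simplified, self-contained sketch of that per-cell lookup (a toy mask, not the scPortrait data model):

import numpy as np
from scipy.sparse import coo_array

def get_mask_coords(mask: np.ndarray, center: tuple[int, int], cell_id: int, width: int = 60) -> np.ndarray:
    # crop a (2*width)-sized window around the center and collect the pixel coordinates of
    # the cell, mirroring the idea behind LMDSelection.__get_coords
    r0 = max(int(center[0]) - width, 0)
    c0 = max(int(center[1]) - width, 0)
    crop = mask[r0 : r0 + 2 * width, c0 : c0 + 2 * width]
    sparse = coo_array(crop == cell_id)  # the commit reads the equivalent .coords tuple
    return np.array(list(zip(sparse.row + r0, sparse.col + c0, strict=True)))

# toy example: one 20x20 square "cell" with id 5
mask = np.zeros((200, 200), dtype=np.uint16)
mask[90:110, 95:115] = 5
print(get_mask_coords(mask, center=(100, 105), cell_id=5).shape)  # (400, 2)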
2 changes: 1 addition & 1 deletion src/scportrait/tools/ml/transforms.py
@@ -19,7 +19,7 @@ def __init__(self, choices=4, include_zero=True):
delta = (360 - angles[-1]) / 2
angles = angles + delta

self.choices = angles
self.choices = angles.tolist()

def __call__(self, tensor):
angle = random.choice(self.choices)
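The `.tolist()` change means `random.choice` now yields plain Python floats rather than `numpy.float64` values, which downstream rotation transforms presumably handle more predictably. A quick illustration of the type difference:

import random

import numpy as np

angles = np.arange(0, 360, 360 / 4)  # analogous to the rotation choices above
print(type(random.choice(angles)))           # <class 'numpy.float64'>
print(type(random.choice(angles.tolist())))  # <class 'float'>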
37 changes: 24 additions & 13 deletions src/scportrait/tools/stitch/_stitch.py
@@ -21,12 +21,7 @@
from scportrait.io.daskmmap import dask_array_from_path
from scportrait.processing.images._image_processing import rescale_image
from scportrait.tools.stitch._utils.ashlar_plotting import plot_edge_quality, plot_edge_scatter
from scportrait.tools.stitch._utils.filereaders import (
BioformatsReaderRescale,
FilePatternReaderRescale,
)
from scportrait.tools.stitch._utils.filewriters import write_ome_zarr, write_spatialdata, write_tif, write_xml
from scportrait.tools.stitch._utils.parallelized_ashlar import ParallelEdgeAligner, ParallelMosaic


class Stitcher:
@@ -65,7 +60,7 @@ def __init__(
do_intensity_rescale: bool | str = True,
rescale_range: tuple = (1, 99),
channel_order: list[str] = None,
reader_type=FilePatternReaderRescale,
reader_type="FilePatternReaderRescale",
orientation: dict = None,
plot_QC: bool = True,
overwrite: bool = False,
@@ -114,6 +109,7 @@ def __init__(

if orientation is None:
orientation = {"flip_x": False, "flip_y": True}

self.input_dir = input_dir
self.slidename = slidename
self.outdir = outdir
@@ -139,6 +135,10 @@ def __init__(
self.orientation = orientation
self.reader_type = reader_type

# workaround for lazy imports of module
if self.reader_type == "FilePatternReaderRescale":
self.reader_type = self.FilePatternReaderRescale

# workflow setup
self.plot_QC = plot_QC
self.overwrite = overwrite
@@ -158,10 +158,20 @@ def _lazy_imports(self):
from ashlar.reg import EdgeAligner, Mosaic
from ashlar.scripts.ashlar import process_axis_flip

from scportrait.tools.stitch._utils.filereaders import (
BioformatsReaderRescale,
FilePatternReaderRescale,
)
from scportrait.tools.stitch._utils.parallelized_ashlar import ParallelEdgeAligner, ParallelMosaic

self.ashlar_thumbnail = thumbnail
self.ashlar_EdgeAligner = EdgeAligner
self.ashlar_Mosaic = Mosaic
self.ashlar_process_axis_flip = process_axis_flip
self.BioformatsReaderRescale = BioformatsReaderRescale
self.FilePatternReaderRescale = FilePatternReaderRescale
self.ParallelEdgeAligner = ParallelEdgeAligner
self.ParallelMosaic = ParallelMosaic

def __exit__(self):
self._clear_cache()
@@ -294,14 +304,14 @@ def _initialize_reader(self):
"""
Initialize the reader for reading image tiles.
"""
if self.reader_type == FilePatternReaderRescale:
if self.reader_type == self.FilePatternReaderRescale:
self.reader = self.reader_type(
self.input_dir,
self.pattern,
self.overlap,
rescale_range=self.rescale_range,
)
elif self.reader_type == BioformatsReaderRescale:
elif self.reader_type == self.BioformatsReaderRescale:
self.reader = self.reader_type(self.input_dir, rescale_range=self.rescale_range)

# setup correct orientation of slide (this depends on microscope used to generate the data)
@@ -564,7 +574,7 @@ class ParallelStitcher(Stitcher):
do_intensity_rescale (bool or "full_image", optional): Flag to indicate whether to rescale image intensities (default is True). Alternatively, set to "full_image" to rescale the entire image.
rescale_range (tuple or dict, optional): If all channels should be rescaled to the same range pass a tuple with the percentiles for rescaling (default is (1, 99)). Alternatively, a dictionary can be passed with the channel names as keys and the percentiles as values if each channel should be rescaled to a different range.
channel_order (list, optional): Order of channels in the generated output mosaic. If none (default value) the order of the channels is left unchanged.
reader_type (class, optional): Type of reader to use for reading image tiles (default is FilePatternReaderRescale).
reader_type (class, optional): Type of reader to use for reading image tiles (default is "FilePatternReaderRescale").
orientation (dict, optional): Dictionary specifying which dimensions of the slide to flip (default is {'flip_x': False, 'flip_y': True}).
plot_QC (bool, optional): Flag to indicate whether to plot quality control (QC) figures (default is True).
overwrite (bool, optional): Flag to indicate whether to overwrite the output directory if it already exists (default is False).
@@ -588,7 +598,7 @@ def __init__(
WGAchannel: str = None,
channel_order: list[str] = None,
overwrite: bool = False,
reader_type=FilePatternReaderRescale,
reader_type="FilePatternReaderRescale",
orientation=None,
cache: str = None,
threads: int = 20,
@@ -613,8 +623,9 @@ def __init__(
overwrite,
cache,
)

# dirty fix to avoid a multithreading error with BioformatsReader until this can be fixed
if self.reader_type == BioformatsReaderRescale:
if self.reader_type == self.BioformatsReaderRescale:
threads = 1
print(
"BioformatsReaderRescale does not support multithreading for calculating the error threshold currently. Proceeding with 1 thread."
@@ -632,7 +643,7 @@ def _initialize_aligner(self):
Returns:
aligner (ParallelEdgeAligner): Initialized ParallelEdgeAligner object.
"""
aligner = ParallelEdgeAligner(
aligner = self.ParallelEdgeAligner(
self.reader,
channel=self.stitching_channel_id,
filter_sigma=self.filter_sigma,
@@ -644,7 +655,7 @@ def _initialize_aligner(self):
return aligner

def _initialize_mosaic(self):
mosaic = ParallelMosaic(
mosaic = self.ParallelMosaic(
self.aligner, self.aligner.mosaic_shape, verbose=True, channels=self.channels, n_threads=self.threads
)
return mosaic
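The reader classes are no longer imported at module load time; `_lazy_imports` binds them to the instance, and the public default becomes the string "FilePatternReaderRescale", which is resolved to the actual class only after the lazy imports have run. A minimal sketch of that pattern (the scportrait import path is taken from this diff, the surrounding class is illustrative):

class LazyReaderMixin:
    def __init__(self, reader_type="FilePatternReaderRescale"):
        self.reader_type = reader_type
        self._lazy_imports()
        # resolve the string placeholder once the heavy dependencies are available
        if self.reader_type == "FilePatternReaderRescale":
            self.reader_type = self.FilePatternReaderRescale

    def _lazy_imports(self):
        # deferred so that importing the module does not pull in ashlar/bioformats
        from scportrait.tools.stitch._utils.filereaders import (
            BioformatsReaderRescale,
            FilePatternReaderRescale,
        )

        self.BioformatsReaderRescale = BioformatsReaderRescale
        self.FilePatternReaderRescale = FilePatternReaderRescale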
