Merge pull request #744 from computational-cell-analytics/dev
Changes for release 1.1
constantinpape authored Oct 18, 2024
2 parents 35cb739 + 6330b72 commit a9a8003
Showing 92 changed files with 6,153 additions and 785 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/test.yaml
@@ -20,7 +20,9 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ["3.11", "3.12"]
+        # 3.12 currently not supported due to issues with nifty.
+        # python-version: ["3.11", "3.12"]
+        python-version: ["3.11"]

     steps:
       - name: Checkout
@@ -30,6 +32,8 @@
         uses: mamba-org/setup-micromamba@v1
         with:
           environment-file: environment_cpu.yaml
+          create-args: >-
+            python=${{ matrix.python-version }}
       # Setup Qt libraries for GUI testing on Linux
       - uses: tlambert03/setup-qt-libs@v1
3 changes: 3 additions & 0 deletions .gitignore
@@ -177,3 +177,6 @@ cython_debug/
 # Torch-em stuff
 checkpoints/
 logs/
+
+# "gpu_jobs" folder where slurm batch submission scripts are saved
+gpu_jobs/
81 changes: 81 additions & 0 deletions development/check_3d_model.py
@@ -0,0 +1,81 @@
import numpy as np
import torch
import micro_sam.util as util

from micro_sam.sam_3d_wrapper import get_3d_sam_model
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D


def predict_3d_model():
    d_size = 8
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_3d = get_3d_sam_model(device, d_size)

    input_ = 255 * np.random.rand(1, d_size, 3, 1024, 1024).astype("float32")
    with torch.no_grad():
        input_ = torch.from_numpy(input_).to(device)
        out = sam_3d(input_, multimask_output=False, image_size=1024)
        print(out["masks"].shape)


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, patch_shape, n_classes):
        self.patch_shape = patch_shape
        self.n_classes = n_classes

    def __len__(self):
        return 5

    def __getitem__(self, index):
        image_shape = (self.patch_shape[0], 3) + self.patch_shape[1:]
        x = np.random.rand(*image_shape).astype("float32")
        label_shape = (self.n_classes,) + self.patch_shape
        y = (np.random.rand(*label_shape) > 0.5).astype("float32")
        return x, y


def get_loader(patch_shape, n_classes, batch_size):
    ds = DummyDataset(patch_shape, n_classes)
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
    loader.shuffle = True
    return loader


# TODO: we are missing the resizing in the model, so currently this only supports 1024x1024
def train_3d_model():
    from micro_sam.training.util import ConvertToSemanticSamInputs

    d_size = 4
    n_classes = 5
    batch_size = 2
    image_size = 512

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_3d = get_3d_sam_model(device, n_classes=n_classes, image_size=image_size)

    train_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size)
    val_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size)

    optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5)

    trainer = SemanticSamTrainer3D(
        name="test-sam",
        model=sam_3d,
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=n_classes,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        compile_model=False,
    )
    trainer.fit(10)


def main():
    # predict_3d_model()
    train_3d_model()


if __name__ == "__main__":
    main()
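As a quick sanity check of the data layout used above (a minimal sketch, assuming the values from train_3d_model): images are stored depth-first with the RGB channel axis second, and labels are one-hot masks with one channel per class.

# Shape arithmetic for DummyDataset with patch_shape=(4, 512, 512) and n_classes=5.
patch_shape, n_classes = (4, 512, 512), 5
image_shape = (patch_shape[0], 3) + patch_shape[1:]  # (4, 3, 512, 512): depth, RGB, height, width
label_shape = (n_classes,) + patch_shape             # (5, 4, 512, 512): one binary mask per class
assert image_shape == (4, 3, 512, 512) and label_shape == (5, 4, 512, 512)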
154 changes: 154 additions & 0 deletions development/instance_segmentation_3d.py
@@ -0,0 +1,154 @@
import napari
from elf.io import open_file
import h5py
import os
import torch
import numpy as np

import micro_sam.sam_3d_wrapper as sam_3d
import micro_sam.util as util
# from micro_sam.segment_instances import (
#     segment_instances_from_embeddings,
#     segment_instances_sam,
#     segment_instances_from_embeddings_3d,
# )
from micro_sam import multi_dimensional_segmentation as mds
from micro_sam.visualization import compute_pca

INPUT_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5"
# EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/embedding-mito-3d.zarr"
EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/usr/nimlufre/"
INPUT_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5"
EMBEDDINGS_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/"
INPUT_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/4007_cutout_1.h5"
EMBEDDINGS_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/embedding-mito-3d.zarr"
TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01"
EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr"


# def cell_segmentation_3d() -> None:
#     with open_file(TIMESERIES_PATH, mode="r") as f:
#         timeseries = f["*.tif"][:50]
#
#     predictor = util.get_sam_model()
#     image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)
#
#     seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)
#
#     v = napari.Viewer()
#     v.add_image(timeseries)
#     v.add_labels(seg)
#     napari.run()


# def _get_dataset_and_reshape(path: str, key: str = "raw", shape: tuple = (32, 256, 256)) -> np.ndarray:
#     with h5py.File(path, "r") as f:
#         # Check if the key exists in the file
#         if key not in f:
#             raise KeyError(f"Dataset with key '{key}' not found in file '{path}'.")
#
#         # Load the dataset
#         dataset = f[key][...]
#
#         # Reshape the dataset
#         if dataset.shape != shape:
#             try:
#                 # Attempt to reshape the dataset to the desired shape
#                 dataset = dataset.reshape(shape)
#             except ValueError:
#                 raise ValueError(f"Failed to reshape dataset with key '{key}' to shape {shape}.")
#
#     return dataset


def get_dataset_cutout(path: str, key: str = "raw", shape: tuple = (32, 256, 256),
                       start_index: tuple = None) -> np.ndarray:
    """Loads a cutout from a dataset in an HDF5 file.

    Args:
        path (str): Path to the HDF5 file.
        key (str, optional): Key of the dataset to load. Defaults to "raw".
        shape (tuple, optional): Desired shape of the cutout. Defaults to (32, 256, 256).
        start_index (tuple, optional): Starting index for the cutout within the dataset.
            Defaults to None, which selects a random starting point within valid bounds.

    Returns:
        np.ndarray: The loaded cutout of the dataset with the specified shape.

    Raises:
        KeyError: If the specified key is not found in the HDF5 file.
        ValueError: If the cutout shape exceeds the dataset dimensions or the starting index is invalid.
    """
    with h5py.File(path, "r") as f:
        if key not in f:
            raise KeyError(f"Dataset with key '{key}' not found in file '{path}'.")

        dataset = f[key]
        dataset_shape = dataset.shape
        print("original data shape", dataset_shape)

        # Validate the cutout shape.
        if any(s > d for s, d in zip(shape, dataset_shape)):
            raise ValueError(f"Cutout shape {shape} exceeds dataset dimensions {dataset_shape}.")

        # Generate a random starting index if none was provided.
        if start_index is None:
            start_index = tuple(np.random.randint(0, d - s + 1) for d, s in zip(dataset_shape, shape))

        # Calculate the end index.
        end_index = tuple(min(i + s, d) for i, s, d in zip(start_index, shape, dataset_shape))

        # Load the cutout.
        cutout = dataset[start_index[0]:end_index[0],
                         start_index[1]:end_index[1],
                         start_index[2]:end_index[2]]
        print("cutout data shape", cutout.shape)

    return cutout


def mito_segmentation_3d() -> None:
    patch_shape = (32, 256, 256)
    start_index = (10, 32, 64)
    data_slice = get_dataset_cutout(INPUT_PATH_LOCAL, shape=patch_shape)  # start_index=start_index

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_type = "vit_b"
    predictor, sam = util.get_sam_model(return_sam=True, model_type=model_type, device=device)

    d_size = 3
    predictor3d = sam_3d.Predictor3D(sam, d_size)
    print(predictor3d)
    # breakpoint()
    predictor3d.model.forward(torch.from_numpy(data_slice), multimask_output=False, image_size=patch_shape)
    # output = predictor3d.model([data_slice], multimask_output=False)  # image_size=patch_shape

    # predictor3d._hash = util.models().registry[model_type]
    # predictor3d.model_name = model_type

    # image_embeddings = util.precompute_image_embeddings(predictor3d, volume, EMBEDDINGS_PATH_CLUSTER)
    # seg = util.segment_instances_from_embeddings_3d(predictor3d, image_embeddings)

    # prediction_filename = os.path.join(EMBEDDINGS_PATH_CLUSTER, f"prediction_{INPUT_PATH_CLUSTER}.h5")
    # with h5py.File(prediction_filename, "w") as prediction_file:
    #     prediction_file.create_dataset("prediction", data=seg)

    # visualize
    # v = napari.Viewer()
    # v.add_image(volume)
    # v.add_labels(seg)
    # v.add_labels(seg_sam)
    # napari.run()


def main():
    # automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py')
    # nucleus_segmentation(use_mws=True)
    mito_segmentation_3d()

    # automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py')
    # cell_segmentation(use_mws=True)
    # cell_segmentation_3d()


if __name__ == "__main__":
    main()
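A minimal usage sketch for the new get_dataset_cutout helper (the file path here is a placeholder, not part of the commit); passing start_index=None selects a random valid starting point:

# Hypothetical input file; any HDF5 volume with a 3D dataset under the key "raw" works.
cutout = get_dataset_cutout("/path/to/volume.h5", key="raw", shape=(32, 256, 256), start_index=None)
assert cutout.shape == (32, 256, 256)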
3 changes: 2 additions & 1 deletion environment_cpu.yaml
@@ -4,10 +4,11 @@ channels:
   - conda-forge
 dependencies:
   - cpuonly
+  # This pin is necessary because later nifty versions have import errors on windows.
   - nifty =1.2.1=*_4
   - imagecodecs
   - magicgui
-  - napari <0.5
+  - napari
   - pip
   - pooch
   - pyqt
3 changes: 2 additions & 1 deletion environment_gpu.yaml
@@ -5,9 +5,10 @@ channels:
   - conda-forge
 dependencies:
   - imagecodecs
+  # This pin is necessary because later nifty versions have import errors on windows.
   - nifty =1.2.1=*_4
   - magicgui
-  - napari <0.5
+  - napari
   - pip
   - pooch
   - pyqt
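A quick way to confirm that the pinned nifty build imports cleanly after solving either environment (a minimal check, not part of the commit):

# If this import fails on Windows, the solver likely picked a newer nifty build than the pin intends.
import nifty
print(nifty.__file__)  # shows which installed build was picked up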
2 changes: 1 addition & 1 deletion examples/annotator_2d.py
@@ -65,7 +65,7 @@ def wholeslide_annotator(use_finetuned_model):

 def main():
     # Whether to use the fine-tuned SAM model for light microscopy data.
-    use_finetuned_model = False
+    use_finetuned_model = True

     # 2d annotator for livecell data
     livecell_annotator(use_finetuned_model)
31 changes: 16 additions & 15 deletions finetuning/evaluation/evaluate_amg.py
@@ -3,20 +3,23 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_amg

-from util import get_paths  # comment this and create a custom function with the same name to run amg on your data
-from util import get_pred_paths, get_default_arguments, VANILLA_MODELS
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run amg on your data
+    get_pred_paths, get_default_arguments
+)


-def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_amg(
-        checkpoint,
-        model_type,
-        experiment_folder,
-        val_image_paths,
-        val_gt_paths,
-        test_image_paths
+        checkpoint=checkpoint,
+        model_type=model_type,
+        experiment_folder=experiment_folder,
+        val_image_paths=val_image_paths,
+        val_gt_paths=val_gt_paths,
+        test_image_paths=test_image_paths,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -32,12 +35,10 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-    if args.checkpoint is None:
-        ckpt = VANILLA_MODELS[args.model]
-    else:
-        ckpt = args.checkpoint
-
-    prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder)
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
+    prediction_folder = run_amg_inference(
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs
+    )
     eval_amg(args.dataset, prediction_folder, args.experiment_folder)
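The peft_kwargs dict is filled from two command-line flags that get_default_arguments is expected to expose. A hedged sketch of how such flags might be declared; the real parser lives in finetuning/evaluation/util.py, which is not shown in this diff:

import argparse

# Assumed flag names, mirroring args.peft_rank / args.peft_module used above.
parser = argparse.ArgumentParser()
parser.add_argument("--peft_rank", type=int, default=None, help="Rank for parameter-efficient fine-tuning.")
parser.add_argument("--peft_module", type=str, default=None, help="Which PEFT module to use.")
args = parser.parse_args([])

peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}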
27 changes: 16 additions & 11 deletions finetuning/evaluation/evaluate_instance_segmentation.py
@@ -3,20 +3,25 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder

-from util import get_paths  # comment this and create a custom function with the same name to run ais on your data
-from util import get_pred_paths, get_default_arguments
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run ais on your data
+    get_pred_paths, get_default_arguments
+)


-def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_instance_segmentation_with_decoder_inference(
+    dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs,
+):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_instance_segmentation_with_decoder(
-        checkpoint,
-        model_type,
-        experiment_folder,
-        val_image_paths,
-        val_gt_paths,
-        test_image_paths
+        checkpoint=checkpoint,
+        model_type=model_type,
+        experiment_folder=experiment_folder,
+        val_image_paths=val_image_paths,
+        val_gt_paths=val_gt_paths,
+        test_image_paths=test_image_paths,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -32,9 +37,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
     prediction_folder = run_instance_segmentation_with_decoder_inference(
-        args.dataset, args.model, args.checkpoint, args.experiment_folder
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs,
     )
     eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)