Merge pull request #40 from Hendrik-code/labeling_step
Adding Vertebra Labeling Phase after both segmentation steps
Hendrik-code authored Jan 10, 2025
2 parents 6fef104 + 1680786 commit 0adbb54
Showing 29 changed files with 1,342 additions and 132 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -34,6 +34,8 @@ nnunetv2 = "2.4.2"
TPTBox = "^0.2.1"
antspyx = "0.4.2"
rich = "^13.6.0"
monai = "^1.3.0"
TypeSaveArgParse = "^1.0.1"


[tool.poetry.dev-dependencies]
3 changes: 2 additions & 1 deletion spineps/__init__.py
@@ -1,6 +1,7 @@
from spineps.entrypoint import entry_point
from spineps.models import get_instance_model, get_semantic_model
from spineps.get_models import get_instance_model, get_labeling_model, get_semantic_model
from spineps.phase_instance import predict_instance_mask
from spineps.phase_labeling import perform_labeling_step
from spineps.phase_post import phase_postprocess_combined
from spineps.phase_semantic import predict_semantic_mask
from spineps.seg_model import Segmentation_Model
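To make the new phase concrete, here is a minimal end-to-end sketch of how the exports above could be composed. Only the imported names appear in this diff; the model identifiers, call signatures, and return values are assumptions, not part of this commit.

# Hypothetical pipeline sketch -- only the imported symbols come from the diff above;
# model identifiers, call signatures, and return values are assumed.
from spineps import (
    get_instance_model,
    get_labeling_model,
    get_semantic_model,
    perform_labeling_step,
    predict_instance_mask,
    predict_semantic_mask,
)

semantic_model = get_semantic_model("t2w")            # placeholder model identifier
instance_model = get_instance_model("instance")       # placeholder model identifier
labeling_model = get_labeling_model("vert_labeling")  # placeholder model identifier

mri_nii = ...  # input MRI as a loaded NIfTI/NII image (loading omitted in this sketch)

# Phase 1: semantic segmentation of the input MRI
seg_nii = predict_semantic_mask(mri_nii, semantic_model)    # return value simplified
# Phase 2: instance segmentation of the individual vertebrae
vert_nii = predict_instance_mask(seg_nii, instance_model)   # return value simplified
# Phase 3 (added by this PR): assign vertebra labels to the instance mask
perform_labeling_step(labeling_model, mri_nii, vert_nii)    # argument order assumed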
Empty file.
140 changes: 140 additions & 0 deletions spineps/architectures/pl_densenet.py
@@ -0,0 +1,140 @@
import os
import sys
from dataclasses import dataclass
from pathlib import Path

import pytorch_lightning as pl
import torch
from monai.networks.nets import DenseNet169
from torch import nn
from TypeSaveArgParse import Class_to_ArgParse


@dataclass
class ARGS_MODEL(Class_to_ArgParse):
classification_conv: bool = False
classification_linear: bool = True
#
n_epoch: int = 100
lr: float = 1e-4
l2_regularization_w: float = 1e-6 # 1e-5 was ok
scheduler_endfactor: float = 1e-3
#
in_channel: int = 1 # 1 for img, will be set elsewhere
not_pretrained: bool = True
#
mse_weighting: float = 0.0
dropout: float = 0.05
weight_decay: float = 0 # 1e-4
#
num_classes: int | None = None # Filled elsewhere
n_channel_p_group: int | None = None # Filled elsewhere


class PLClassifier(pl.LightningModule):
def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]):
super().__init__()
self.opt = opt
assert isinstance(opt.num_classes, int), opt.num_classes
self.num_classes: int = opt.num_classes
self.group_2_n_channel = group_2_n_channel
# save hyperparameter, everything below not visible
self.save_hyperparameters()

self.net, linear_in = get_architecture(
DenseNet169, opt.in_channel, opt.num_classes, pretrained=False, remove_classification_head=True
)
self.classification_heads = self.build_classification_heads(linear_in, opt.classification_conv, opt.classification_linear)
self.classification_keys = list(self.classification_heads.keys())
self.mse_weighting = opt.mse_weighting

self.metrics_to_log = ["f1", "mcc", "acc", "auroc", "f1_avg"]
self.metrics_to_log_overall = ["f1", "f1_avg"]

self.train_step_outputs = []
self.val_step_outputs = []
self.softmax = nn.Softmax(dim=1) # use this group-wise?
self.sigmoid = nn.Sigmoid()
self.cross_entropy = nn.CrossEntropyLoss()
self.mse = nn.MSELoss(reduction="none")
self.l2_reg_w = opt.l2_regularization_w
print(f"{self._get_name()} loaded with", opt)

def forward(self, x):
features = self.net(x)
return {k: v(features) for k, v in self.classification_heads.items()}

def training_step(self, batch, _):
img, logits, logits_soft, pred_cls, label_onehot, label, losses, loss = self._shared_step(batch)
# Log
self.log("loss/train_loss", loss, batch_size=img.shape[0], prog_bar=True)
#
for k, v in losses.items():
for kk, kv in v.items():
self.log(f"loss_train_{k}/{kk}", kv.item(), batch_size=img.shape[0], prog_bar=False)
# self._shared_metric_append({"pred": pred_cls, "gt": label}, self.train_step_outputs)
self.train_step_outputs.append({"preds": pred_cls, "labels": label})
return loss

def validation_step(self, batch, _):
img, logits, logits_soft, pred_cls, label_onehot, label, losses, loss = self._shared_step(batch)
self.log("loss/val_loss", loss)
self.val_step_outputs.append({"preds": pred_cls, "labels": label})
return loss

def _shared_step(self, batch):
img = batch["img"]
label = batch["label"] # onehot
#
gt_label = {k: torch.max(v, 1)[1] for k, v in label.items()}
logits = self.forward(img)
#
logits_soft = {k: self.softmax(v) for k, v in logits.items()}
pred_cls = {k: torch.max(v, 1)[1] for k, v in logits_soft.items()}

losses = {k: self.loss(logits[k], label[k]) for k in label.keys()}
loss = self.loss_merge(losses)
return img, logits, logits_soft, pred_cls, label, gt_label, losses, loss

def build_classification_heads(self, linear_in: int, convolution_first: bool, fully_connected: bool):
def construct_one_head(output_classes: int):
modules = []
n_channel = linear_in
n_channel_next = linear_in
if convolution_first:
n_channel_next = n_channel // 2
modules.append(nn.Conv3d(n_channel, n_channel_next, kernel_size=(3, 3, 3), device="cuda:0"))
n_channel = n_channel_next
if fully_connected:
n_channel_next = n_channel // 2
modules.append(nn.Linear(n_channel, n_channel_next, device="cuda:0"))
modules.append(nn.ReLU())
n_channel = n_channel_next
modules.append(nn.Linear(n_channel, output_classes, device="cuda:0"))

return nn.Sequential(*modules)

return nn.ModuleDict({k: construct_one_head(v) for k, v in self.group_2_n_channel.items()})

def __str__(self) -> str:
return "VertebraLabelingModel"


def get_architecture(
model,
in_channel: int = 1,
out_channel: int = 1,
pretrained: bool = True,
remove_classification_head: bool = True,
):
model = model(
spatial_dims=3,
in_channels=in_channel,
out_channels=out_channel,
pretrained=pretrained,
)
linear_infeatures = 0
linear_infeatures = model.class_layers[-1].in_features
if remove_classification_head:
model.class_layers = model.class_layers[:-1]
return model, linear_infeatures
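As a quick orientation for the new module, a minimal sketch of how get_architecture and PLClassifier fit together. The class counts, group names, and input shape below are invented for the example, and the classification heads are created on cuda:0 as in the code above, so a GPU is assumed.

# Hypothetical example -- class counts, group names, and input shape are invented for illustration.
import torch
from monai.networks.nets import DenseNet169
from spineps.architectures.pl_densenet import ARGS_MODEL, PLClassifier, get_architecture

# get_architecture() drops the DenseNet's final Linear layer and returns its in_features,
# so each group-specific head can be stacked on the pooled backbone features.
backbone, linear_in = get_architecture(DenseNet169, in_channel=1, out_channel=2, pretrained=False)
print(linear_in)  # 1664 for DenseNet169

opt = ARGS_MODEL(num_classes=26)  # placeholder class count
model = PLClassifier(opt, group_2_n_channel={"vertebra": 26, "region": 3}).cuda()
# The heads are built on cuda:0, so a CUDA device is required for construction and forward.

x = torch.randn(2, 1, 96, 96, 96, device="cuda:0")  # arbitrary 3D patch size
out = model(x)  # dict of logits per group: {"vertebra": [2, 26], "region": [2, 3]}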
19 changes: 9 additions & 10 deletions spineps/Unet3D/pl_unet.py → spineps/architectures/pl_unet.py
@@ -1,15 +1,16 @@
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics.functional as mF # noqa: N812
from torch import nn
from torch.optim import lr_scheduler
import torchmetrics.functional as mF

from spineps.Unet3D.unet3D import Unet3D
from spineps.architectures.unet3D import Unet3D


class PLNet(pl.LightningModule):
def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> None:
def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> None: # noqa: N803, ARG002
super().__init__()
self.save_hyperparameters()

@@ -63,7 +64,7 @@ def on_train_epoch_end(self) -> None:
self.logger.experiment.add_text("train_dice_p_cls", str(metrics["dice_p_cls"].tolist()), self.current_epoch)
self.train_step_outputs.clear()

def validation_step(self, batch, batch_idx):
def validation_step(self, batch, _):
loss, logits, gt, pred_cls = self._shared_step(batch["target"], batch["class"], detach2cpu=True)
loss = loss.detach().cpu()
metrics = self._shared_metric_step(loss, logits, gt, pred_cls)
@@ -90,7 +91,7 @@ def configure_optimizers(self):
return {"optimizer": optimizer}

def loss(self, logits, gt):
return 0.0 # TODO don't use this for training
return logits, gt # TODO don't use this for training

def _shared_step(self, target, gt, detach2cpu: bool = False):
logits = self.forward(target)
@@ -108,9 +109,9 @@ def _shared_step(self, target, gt, detach2cpu: bool = False):
pred_cls = pred_cls.detach().cpu()
return loss, logits, gt, pred_cls

def _shared_metric_step(self, loss, logits, gt, pred_cls):
def _shared_metric_step(self, loss, _, gt, pred_cls):
dice = mF.dice(pred_cls, gt, num_classes=self.n_classes)
diceFG = mF.dice(pred_cls, gt, num_classes=self.n_classes, ignore_index=0)
diceFG = mF.dice(pred_cls, gt, num_classes=self.n_classes, ignore_index=0) # noqa: N806
dice_p_cls = mF.dice(pred_cls, gt, average=None, num_classes=self.n_classes)
return {"loss": loss.detach().cpu(), "dice": dice, "diceFG": diceFG, "dice_p_cls": dice_p_cls}

@@ -123,8 +124,6 @@ def _shared_metric_append(self, metrics, outputs):
def _shared_cat_metrics(self, outputs):
results = {}
for m, v in outputs.items():
# v = np.asarray(v)
# print(m, v.shape)
stacked = torch.stack(v)
results[m] = torch.mean(stacked) if m != "dice_p_cls" else torch.mean(stacked, dim=0)
return results