diff --git a/recognition/prostate2dnet_46978116/README.MD b/recognition/prostate2dnet_46978116/README.MD new file mode 100644 index 0000000000..8379ff8d4e --- /dev/null +++ b/recognition/prostate2dnet_46978116/README.MD @@ -0,0 +1,104 @@ +# Segmenting HipMRI Study for prostate cancer radiotherapy Data + +In this report, we will attempt to segment HipMRI data based on the part of the body the image depicts with the help of 2D U-Net Convolutional Neural Network. + +# Understanding the Dataset + +The dataset consists of 2D MRI slices of the male pelvis collected from a radiation therapy study at Calvary Mater Newcastle Hospital. Each MRI image has a corresponding segmentation image, which we will refer to as the mask, that segments the MRI image based on the part of the body depicted in each segment. There are six classes to these segments: the background, bladder, body, bone, rectum and prostate. Each image is roughly 256 x 128 pixels although some are slightly larger than this so they have all been resized to 256 x 256. Our task is to develop a method that accurately segments MRI data of the male pelvis based on training data collected from these images and their masks. + +# The 2D U-Net Model + +The chosen model to learn this segmentation is a 2D U-Net. The 2D U-Net a deep neural network developed by Ronneburg, etal in 2015. This model shows promising performance in segmentation of medical data, making it an ideal fit for the task. The 2D U-Net utilisizes an encoder-decoder structure with skip connections. This makes it effective in computervision segmentation problems as it can capture the spatial information of the data well and retains information through the skip connections between the encoder and decoder layers. The model also trains quickly with high accuracy. + +![2D U-Net Architecture](images\2dunet_architecture.png) + +The architecture consists of two 2D convolutional layers which are each followed by max pooling layers. The activation function is ReLu and a 2D Batch normalization is used as well. Although Batch normalization was not used in the original paper (because it wasn't developed yet) we have used it here to improve performance. The data is downsampled using these blocks until it reaches a bottleneck at 512 feature channels. Here, it is upscaled back to the original image size using upsampling layers with skip connections to retain information while reaching a detailed resolution. The final layer has 6 outchannels which corresponds with the 6 classes of segments. + +# Output and Results + +After training on the training set with early stoppage the the dice co-efficients on the test set using cross-entropy loss and minimal transformations: + +Class 0: 0.9945 +Class 1: 0.9768 +Class 2: 0.8689 +Class 3: 0.6563 +Class 4: 0.3489 +Class 5: 0.2344 +Average Dice Score: 0.6800 + +Below are the plots of the training performance: + +![Training Performance CrossEntropyLoss](images\graphs\crossplot.png) + +And the predictions made by the model: + +![Best Predictions [Epoch 17] CrossEntropyLoss (Classes: 0, 1, 2)](images\predictions\bestcross10\epoch_test_classes_0-2.png) + +![Best Predictions [Epoch 17] CrossEntropyLoss (Classes: 3, 4, 5)](images\predictions\bestcross10\epoch_test_classes_3-5.png) + + +close up predictions: + + +![Best Predictions Close Up [Epoch 17] CrossEntropyLoss (Classes: 0, 1, 2)](images\predictions\bestcross2\epoch_test_classes_0-2.png) + +![Best Predictions Close Up [Epoch 17] CrossEntropyLoss (Classes: 3, 4, 5)](images\predictions\bestcross2\epoch_test_classes_3-5.png) + + +On the left is the MRI image, the middle is real segment for the corresponding MRI image and the right is the prediction from the model. + +Notably, the model underperformed in segmenting for class 3, 4 and 5. These classes were identified early on as having poor performance compared to other classes which can be attributed to their very low prevelance rate. The vast majority of MRI slides do not contain these segments and if they do its very small compared to the other classes. + +To try to combat this transformations were performed on the images to try to improve segmenting of these classes such as rotations and zooming. Furthermore, the loss function was changed to directly include the dice loss (1 - dice coefficient) for each class. This dice loss had an added penalty when the co-efficient was below 0.75 to further incentivize the model to learn how to segment class 4 and 5. Below is the test dice coefficients after these changes, showing a very large improvement especially in the problem classes 3, 4, and 5: + +Class 0: 0.9863 +Class 1: 0.9777 +Class 2: 0.8820 +Class 3: 0.7254 +Class 4: 0.6833 +Class 5: 0.8711 +Average Dice Score: 0.8543 + +And the plots of training performance and predictions also show improvement: + +![Training Performance DiceLoss](images\graphs\diceplot.png) + +The predictions made by the model: + +![Best Predictions [Epoch 40] DiceLoss (Classes: 0, 1, 2)](images\predictions\bestdice10\epoch_test_classes_0-2.png) + +![Best Predictions [Epoch 40] DiceLoss (Classes: 3, 4, 5)](images\predictions\bestdice10\epoch_test_classes_3-5.png) + + +close up predictions: + +![Best Predictions Close Up [Epoch 40] DiceLoss (Classes: 0, 1, 2)](images\predictions\bestdice2\epoch_test_classes_0-2.png) + +![Best Predictions Close Up [Epoch 40] DiceLoss (Classes: 3, 4, 5)](images\predictions\bestdice2\epoch_test_classes_3-5.png) + + +# Requirements + +torch +torchvision +scikit-image +tqdm +nibabel +numpy +albumentatinos +matplotlib + +# References + +[1] O. Ronneberger, P. Fischer, and T. Brox, “U-Net: Convolutional Networks for Biomedical Image Segmentation,” in Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, ser. Lecture Notes in Computer Science, N. Navab, J. Hornegger, W. M. Wells, and A. F. Frangi, Eds. Cham: Springer International Publishing, 2015, pp. 234–241. Available: https://doi.org/10.48550/arXiv.1505.04597 + +[2] J. Schmidt, “Creating and training a U-Net model with PyTorch for 2D & 3D semantic segmentation: Dataset building,” Towards Data Science, Dec. 2, 2020. [Online]. Available: https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-dataset-fb1f7f80fe55. + +[3] A. Persson, “PyTorch Image Segmentation Tutorial with U-NET: everything from scratch baby,” YouTube, Feb. 23, 2021. [Online]. Available: https://www.youtube.com/watch?v=IHq1t7NxS8k. + +[4] Esri, “How U-net works,” ArcGIS API for Python Documentation. [Online]. Available: https://developers.arcgis.com/python/latest/guide/how-unet-works/. + +[5] J. Dowling and P. Greer, “Labelled weekly MR images of the male pelvis,” CSIRO Data Access Portal, Sep. 20, 2021. [Online]. Available: https://doi.org/10.25919/45t8-p065. + + + diff --git a/recognition/prostate2dnet_46978116/dataset.py b/recognition/prostate2dnet_46978116/dataset.py new file mode 100644 index 0000000000..a7b99e8549 --- /dev/null +++ b/recognition/prostate2dnet_46978116/dataset.py @@ -0,0 +1,106 @@ + +import numpy as np +import nibabel as nib +import os +import torch +from tqdm import tqdm +from torch.utils.data import Dataset +from skimage.transform import resize +import torchvision.transforms as transforms + + +def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray: + channels = np.unique(arr) + res = np.zeros(arr.shape + (len(channels),), dtype=dtype) + for c in channels: + c = int(c) + res[..., c:c+1][arr == c] = 1 + + return res + + +# load medical image functions +def load_data_2D(imageNames, normImage = False, categorical = False, dtype = np.float32, getAffines = False, early_stop = False): + ''' + Load medical image data from names, cases list provided into a list for each. + This function pre - allocates 4D arrays for conv2d to avoid excessive memory & + usage. + normImage: bool (normalise the image 0.0 -1.0) + early_stop: Stop loading pre - maturely, leaves arrays mostly empty, for quick & + loading and testing scripts. + ''' + affines = [] + + # get fixed size + num = len(imageNames) + first_case = nib.load(imageNames[0]).get_fdata(caching = 'unchanged') + if len(first_case.shape) == 3: + first_case = first_case [:,:,0] # sometimes extra dims, remove + if categorical: + first_case = to_channels(first_case, dtype = dtype) + rows, cols, channels = first_case.shape + images = np.zeros((num, rows, cols, channels), dtype = dtype) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype = dtype) + + for i, inName in enumerate(imageNames): + niftiImage = nib.load(inName) + inImage = niftiImage.get_fdata(caching = 'unchanged') # read disk only + affine = niftiImage.affine + if len(inImage.shape) == 3: + inImage = inImage[:,:,0] # sometimes extra dims in HipMRI_study data + inImage = inImage.astype(dtype) + if normImage: + #~ inImage = inImage / np.linalg.norm(inImage) + #~ inImage = 255. * inImage / inImage.max() + inImage = (inImage - inImage.mean()) / inImage.std() + if categorical: + inImage = to_channels(inImage, dtype = dtype) + images[i,:,:,:] = inImage + else: + images[i,:,:] = inImage + + affines.append(affine) + if i > 20 and early_stop: + break + + if getAffines: + return torch.tensor(images, dtype = torch.float32), affines + else: + return torch.tensor(images, dtype = torch.float32) + +# Dataset structure for loading images and masks into dataloader +class ProstateDataset(Dataset): + def __init__(self, image_path, mask_path, norm_image=False, transform=None, target_size=(128, 64)): + self.transform = transform + self.image_path = image_path + self.mask_path = mask_path + + # list of paths + self.images = os.listdir(self.image_path) + self.masks = os.listdir(self.mask_path) + + + self.target_size = target_size + self.transform = transform + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + # load with helper + img_pth = os.path.join(self.image_path,self.images[idx]) + mask_pth= os.path.join(self.mask_path,self.images[idx].replace('case', 'seg')) + image = load_data_2D([img_pth], normImage=True) + mask = load_data_2D([mask_pth]) + + # Apply transformations + image = transforms.Resize((256, 256))(image) + mask = transforms.Resize((256, 256))(mask) + + mask = mask.long() + + + + return image, mask diff --git a/recognition/prostate2dnet_46978116/images/2dunet_architecture.png b/recognition/prostate2dnet_46978116/images/2dunet_architecture.png new file mode 100644 index 0000000000..876e18c6e7 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/2dunet_architecture.png differ diff --git a/recognition/prostate2dnet_46978116/images/graphs/crossplot.png b/recognition/prostate2dnet_46978116/images/graphs/crossplot.png new file mode 100644 index 0000000000..fb045b0d25 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/graphs/crossplot.png differ diff --git a/recognition/prostate2dnet_46978116/images/graphs/diceplot.png b/recognition/prostate2dnet_46978116/images/graphs/diceplot.png new file mode 100644 index 0000000000..dfb24ab302 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/graphs/diceplot.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_0-2.png b/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_0-2.png new file mode 100644 index 0000000000..4c818579f5 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_0-2.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_3-5.png b/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_3-5.png new file mode 100644 index 0000000000..d114f367c3 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestcross10/epoch_test_classes_3-5.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_0-2.png b/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_0-2.png new file mode 100644 index 0000000000..179d42b20b Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_0-2.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_3-5.png b/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_3-5.png new file mode 100644 index 0000000000..85518c0e96 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestcross2/epoch_test_classes_3-5.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_0-2.png b/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_0-2.png new file mode 100644 index 0000000000..66f6cca8fb Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_0-2.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_3-5.png b/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_3-5.png new file mode 100644 index 0000000000..f93f9f04cd Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestdice10/epoch_test_classes_3-5.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_0-2.png b/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_0-2.png new file mode 100644 index 0000000000..7b5e15f34f Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_0-2.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_3-5.png b/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_3-5.png new file mode 100644 index 0000000000..90255f5bd7 Binary files /dev/null and b/recognition/prostate2dnet_46978116/images/predictions/bestdice2/epoch_test_classes_3-5.png differ diff --git a/recognition/prostate2dnet_46978116/images/predictions/test_cross_scores.txt b/recognition/prostate2dnet_46978116/images/predictions/test_cross_scores.txt new file mode 100644 index 0000000000..474c9ae4b0 --- /dev/null +++ b/recognition/prostate2dnet_46978116/images/predictions/test_cross_scores.txt @@ -0,0 +1,7 @@ +Class 0: 0.9945 +Class 1: 0.9768 +Class 2: 0.8689 +Class 3: 0.6563 +Class 4: 0.3489 +Class 5: 0.2344 +Average Dice Score: 0.6800 diff --git a/recognition/prostate2dnet_46978116/images/predictions/test_dice_scores.txt b/recognition/prostate2dnet_46978116/images/predictions/test_dice_scores.txt new file mode 100644 index 0000000000..87c4a7bdcd --- /dev/null +++ b/recognition/prostate2dnet_46978116/images/predictions/test_dice_scores.txt @@ -0,0 +1,7 @@ +Class 0: 0.9863 +Class 1: 0.9777 +Class 2: 0.8820 +Class 3: 0.7254 +Class 4: 0.6833 +Class 5: 0.8711 +Average Dice Score: 0.8543 diff --git a/recognition/prostate2dnet_46978116/modules.py b/recognition/prostate2dnet_46978116/modules.py new file mode 100644 index 0000000000..418441270a --- /dev/null +++ b/recognition/prostate2dnet_46978116/modules.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torchvision.transforms.functional as TF + + +class UNet(nn.Module): + def __init__(self): + super(UNet, self).__init__() + + # Down double conv layers + self.down1 = DoubleConv(in_channels=1, out_channels=64) + self.down2 = DoubleConv(in_channels=64, out_channels=128) + self.down3 = DoubleConv(in_channels=128, out_channels=256) + + # Bottle neck + self.down4 = DoubleConv(in_channels=256, out_channels=512) + + #max pool layer + self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Up transpose layers + Double Conv + self.up_trans1 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2) + self.up1 = DoubleConv(512, 256) + + + self.up_trans2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2,stride=2) + self.up2 = DoubleConv(256, 128) + + self.up_trans3 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2) + self.up3 = DoubleConv(128, 64) + + + self.final = nn.Conv2d(in_channels=64,out_channels=6,kernel_size=1) + + + def forward(self, initial): + + + # down + c1 = self.down1(initial) + p1 = self.max_pool(c1) + + c2 = self.down2(p1) + p2 = self.max_pool(c2) + + c3 = self.down3(p2) + p3 = self.max_pool(c3) + + + + c4 = self.down4(p3) + + # upsample + t1 = self.up_trans1(c4) + d1 = self.up1(torch.cat([t1, c3], 1)) + + t2 = self.up_trans2(d1) + d2 = self.up2(torch.cat([t2, c2], 1)) + + t3 = self.up_trans3(d2) + d3 = self.up3(torch.cat([t3, c1], 1)) + + + # output + out = self.final(d3) + return out + + +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(DoubleConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride = 1, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride= 1, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + + def forward(self,x): + return self.conv(x) diff --git a/recognition/prostate2dnet_46978116/predict.py b/recognition/prostate2dnet_46978116/predict.py new file mode 100644 index 0000000000..935e7a891d --- /dev/null +++ b/recognition/prostate2dnet_46978116/predict.py @@ -0,0 +1,89 @@ +# predict.py + +import torch +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from torch.utils.data import DataLoader +from dataset import ProstateDataset +from modules import UNet +from train import CombinedLoss, validate_fn, save_img, plot_metrics +import os +import numpy as np +import matplotlib.pyplot as plt +from tqdm import tqdm + +# Configuration Parameters +batch_size = 64 +n_workers = 2 +pin = True +device = "cuda" if torch.cuda.is_available() else "cpu" +img_height = 256 +img_width = 128 +num_classes = 6 +best_model_path = "savedmodels/bestDice.pth" # Path to the saved best model +test_image_dir = 'keras_slices_data/keras_slices_test' # Ensure these match test folder +test_mask_dir = 'keras_slices_data/keras_slices_seg_test'# +output_folder = 'dice_test_predictions' + +# + + +def main(): + # Define transformation pipeline for testing (no augmentations, only resizing and normalization) + test_transform = A.Compose([ + A.Resize(height=img_height, width=img_width), + A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0), + ToTensorV2() + ]) + + # Instantiate the test dataset and DataLoader + test_dataset = ProstateDataset( + image_path=test_image_dir, + mask_path=test_mask_dir, + norm_image=True, + transform=test_transform + ) + test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=n_workers, + pin_memory=pin + ) + + # Initialize the model and load the best saved weights + model = UNet().to(device=device) + if os.path.exists(best_model_path): + model.load_state_dict(torch.load(best_model_path, map_location=device)) + print(f"Loaded best model weights from '{best_model_path}'.") + else: + print(f"Best model not found at '{best_model_path}'. Exiting.") + return + + # Define the loss function (same as used during training) + ce_weights = torch.tensor([1, 1, 1, 2, 10, 4], dtype=torch.float).to(device) + loss_fn = CombinedLoss(ce_weight=ce_weights) + + # Evaluate the model on the test set and save dice scores + print("Evaluating the model on the test set...") + test_loss, test_dice = validate_fn(test_loader, model, loss_fn) + print(f"Test Loss: {test_loss:.4f}") + print(f"Test Dice Coefficients per class: {test_dice}") + + # Save prediction images for qualitative assessment + print(f"Saving prediction images to '{output_folder}'...") + save_img(test_loader, model, folder=output_folder, device=device, num_classes=num_classes, epoch='test') + save_img(test_loader, model, folder=output_folder + 'small', device=device, num_classes=num_classes, epoch='test',max_images_per_class=2) + # Compute overall Dice score (average across classes) + average_dice = np.mean(list(test_dice.values())) + print(f"Average Dice Score across all classes: {average_dice:.4f}") + + # save dice scores to file + with open(os.path.join(output_folder, 'test_dice_scores.txt'), 'w') as f: + for cls, score in test_dice.items(): + f.write(f"Class {cls}: {score:.4f}\n") + f.write(f"Average Dice Score: {average_dice:.4f}\n") + print(f"Dice scores saved to '{output_folder}/test_dice_scores.txt'.") + +if __name__ == '__main__': + main() diff --git a/recognition/prostate2dnet_46978116/train.py b/recognition/prostate2dnet_46978116/train.py new file mode 100644 index 0000000000..3df8a9d0d5 --- /dev/null +++ b/recognition/prostate2dnet_46978116/train.py @@ -0,0 +1,406 @@ +import torch +import torchvision +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from tqdm import tqdm +import torch.nn as nn +import torch.optim as optim +from dataset import ProstateDataset +from torch.utils.data import DataLoader +from modules import UNet +import os +import matplotlib.pyplot as plt +import numpy as np +import torch.optim.lr_scheduler as lr_scheduler +import matplotlib.pyplot as plt + + +batch_size = 64 +N_epochs = 100 +n_workers = 2 +pin = True +device = "cuda" if torch.cuda.is_available() else "cpu" +load_model = False +img_height = 256 +img_width = 128 +learning_rate = 0.0005 + + + +# path to image folders ensure they are same +train_image_dir = 'keras_slices_data/keras_slices_train' +train_mask_dir = 'keras_slices_data/keras_slices_seg_train' +val_image_dir = 'keras_slices_data/keras_slices_validate' +val_mask_dir = 'keras_slices_data/keras_slices_seg_validate' + +def train_fn(loader,model,optimizer,loss_fn,scaler): + loop = tqdm(loader) + total_loss = 0.0 + + for batch_idx, (data,targets) in enumerate(loop): + data = data.to(device=device) + targets = targets.float().squeeze(1).to(device=device) + targets = targets.long() + + with torch.amp.autocast(device_type=device): + predictions = model(data) + loss = loss_fn(predictions, targets) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + total_loss += loss.item() + loop.set_postfix(loss=loss.item()) + return total_loss / len(loader) + + +class DiceLoss(nn.Module): + def __init__(self, smooth=1e-5, ignore_index=None, num_classes=6, threshold=0.75, k=20.0): + """ + Initializes the DiceLoss. + + Parameters: + smooth (float): Smoothing factor to avoid division by zero. + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. + num_classes (int): Number of segmentation classes. + threshold (float): Dice score threshold below which additional penalty is applied. + k (float): Steepness parameter for the sigmoid function to approximate the step penalty. + """ + super(DiceLoss, self).__init__() + self.smooth = smooth + self.ignore_index = ignore_index + self.num_classes = num_classes + self.threshold = threshold + self.k = k + self.last_dice_coeff = None + + def forward(self, inputs, targets): + """ + Forward pass for DiceLoss. + + Parameters: + inputs (torch.Tensor): Predicted logits from the model (before softmax). + Shape: [batch_size, num_classes, H, W] + targets (torch.Tensor): Ground truth mask. + Shape: [batch_size, H, W] + + Returns: + torch.Tensor: Computed Dice loss. + """ + # Ensure targets are of type long + targets = targets.long() + + # Convert targets to one-hot encoding + targets_one_hot = torch.zeros_like(inputs) + targets_one_hot.scatter_(1, targets.unsqueeze(1), 1) + + if self.ignore_index is not None: + mask = targets != self.ignore_index + inputs = inputs * mask.unsqueeze(1) + targets_one_hot = targets_one_hot * mask.unsqueeze(1) + + # Apply softmax to get probabilities + inputs = torch.softmax(inputs, dim=1) + + # Flatten the tensors + inputs = inputs.view(inputs.size(0), inputs.size(1), -1) # [batch_size, num_classes, H*W] + targets_one_hot = targets_one_hot.view(targets_one_hot.size(0), targets_one_hot.size(1), -1) # [batch_size, num_classes, H*W] + + # Compute intersection and union + intersection = (inputs * targets_one_hot).sum(-1) # [batch_size, num_classes] + total = inputs.sum(-1) + targets_one_hot.sum(-1) # [batch_size, num_classes] + + # Compute Dice coefficient + dice_coeff = (2. * intersection + self.smooth) / (total + self.smooth) # [batch_size, num_classes] + + # Average Dice coefficient over batch + dice_coeff = dice_coeff.mean(dim=0) # [num_classes] + + # Store the Dice coefficients + self.last_dice_coeff = dice_coeff.detach().cpu().numpy() + + # Compute penalty using sigmoid + penalty = torch.sigmoid(self.k * (self.threshold - dice_coeff)) # [num_classes] + + # Compute cost per class + cost = 1 - dice_coeff + penalty # [num_classes] + + # Compute dice_cost as the sum of squared costs + dice_cost = torch.sum(cost ** 2) + + return dice_cost + + def get_last_dice_coeff(self): + return self.last_dice_coeff + + +# Combined Loss (cross with dice) +class CombinedLoss(nn.Module): + def __init__(self, ce_weight=None, dice_weight=1.0, ce_weight_factor=1.0, dice_weight_factor=3.0, cross=False): + super(CombinedLoss, self).__init__() + self.ce_loss = nn.CrossEntropyLoss(weight=ce_weight) + dice_cost = 0 + self.dice_loss = DiceLoss() + self.ce_weight = ce_weight_factor + self.dice_weight = dice_weight_factor + self.last_dice_coeff = None + self.cross = cross + + def forward(self, inputs, targets): + dl = self.dice_loss(inputs, targets) + self.last_dice_coeff = self.dice_loss.get_last_dice_coeff() + ce = self.ce_loss(inputs, targets) + if not self.cross: + return self.ce_weight * ce + self.dice_weight * dl + else: + return ce + + + def get_last_dice_coeff(self): + """ + Retrieves the last computed Dice coefficients. + """ + return self.last_dice_coeff + +def main(): + train_trainsform = A.Compose( + [ A.Resize(height=img_height,width=img_width), + A.Rotate(limit=35, p=1.0), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.1), + A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0), + ToTensorV2()] + + ) + + train_dataset = ProstateDataset(train_image_dir, train_mask_dir, norm_image=True, transform=train_trainsform) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, pin_memory=True) + val_dataset = ProstateDataset(val_image_dir, val_mask_dir, norm_image=True,transform=train_trainsform) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=True) + + weights = [1, 1, 1, 2, 6, 3] + weights = torch.tensor(weights, dtype=torch.float).to(device) + model = UNet().to(device=device) + loss_fn = CombinedLoss(ce_weight=weights, cross=False) + scaler = torch.amp.GradScaler(device=device) + optimizer = optim.Adam(model.parameters(), lr = learning_rate) + train_losses = [] + val_dice_scores = [] + val_losses = [] + best_val_loss = float('inf') + + for epoch in range(N_epochs): + print(f"\nEpoch [{epoch+1}/{N_epochs}]") + avg_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler) + train_losses.append(avg_loss) + print(f"Average Training Loss: {avg_loss:.4f}") + + # Validation Phase - Compute Validation Loss + val_loss = validate_fn(val_loader, model, loss_fn) + val_losses.append(val_loss[0]) + val_dice_scores.append(val_loss[1]) + print(f"Average Validation Loss: {val_loss[0]:.4f}") + print(f"Dice Coeff: {val_loss[1]}") + if epoch >= 9: + torch.save(model.state_dict(), f"dice/bestsofarDICE_{epoch}unet2d.pth") + print("Best model saved.") + # Save prediction images + print("Saving best prediction images") + save_img(val_loader, model, folder=f"dice/DICEONLY10epoch_{epoch}", device=device, num_classes=6, epoch=epoch, max_images_per_class=10) + save_img(val_loader, model, folder=f"dice/DICEONLY2epoch_{epoch}", device=device, num_classes=6, epoch=epoch, max_images_per_class=2) + # Plot metrics + plot_metrics(train_losses,val_losses, val_dice_scores, num_classes=6, save_path=f"dice/DICEONLYmetrics_plot_epoch_{epoch}.png") + + + +def plot_metrics(train_losses, val_losses, val_dice_scores, num_classes=6, save_path="metrics_plot.png"): + """ + Plots training and validation losses, and Dice scores per class over epochs. + + Parameters: + train_losses (list): List of training losses per epoch. + val_losses (list): List of validation losses per epoch. + val_dice_scores (list): List of Dice score dictionaries per epoch. + num_classes (int): Number of segmentation classes. + save_path (str): Path to save the plot. + """ + epochs = range(1, len(train_losses) + 1) + + # Create subplots + plt.figure(figsize=(18, 6)) + + # Plot Training and Validation Loss + plt.subplot(1, 2, 1) + plt.plot(epochs, train_losses, 'b-', label='Training Loss') + plt.plot(epochs, val_losses, 'r-', label='Validation Loss') + plt.title('Training and Validation Loss over Epochs') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend() + + # Plot Dice Scores per Class + plt.subplot(1, 2, 2) + for cls in range(num_classes): + cls_scores = [epoch_dice[cls] for epoch_dice in val_dice_scores] + plt.plot(epochs, cls_scores, label=f'Class {cls}') + plt.title('Dice Score per Class over Epochs') + plt.xlabel('Epoch') + plt.ylabel('Dice Score') + plt.legend() + + plt.tight_layout() + plt.savefig(save_path) + + + +def save_img(loader, model, folder="images", device="cuda", num_classes=6, max_images_per_class=10, epoch=1, classes_per_row=3): + """ + Saves prediction images for each class, stacking multiple classes on the same page. + + Parameters: + loader (DataLoader): DataLoader for the dataset to save images from. + model (nn.Module): Trained model for making predictions. + folder (str): Directory to save the images. + device (str): Device to perform computations on ('cuda' or 'cpu'). + num_classes (int): Number of segmentation classes. + max_images_per_class (int): Maximum number of images to save per class. + epoch (int): Current epoch number for filename reference. + classes_per_row (int): Number of classes to display per row in the stacked image. + """ + model.eval() + os.makedirs(folder, exist_ok=True) + + green = [0, 255, 0] # Green + black = [0, 0, 0] # Black + + # Dictionary to hold images for each class + class_images = {cls: [] for cls in range(num_classes)} + + with torch.no_grad(): + for batch_idx, (x, y) in enumerate(loader): + x = x.to(device=device) + y = y.to(device=device).long() + + # Remove any extra singleton dimensions (e.g., [batch_size, 1, H, W] -> [batch_size, H, W]) + while y.dim() > 3: + y = y.squeeze(1) + + preds = model(x) # Shape: [batch_size, num_classes, H, W] + preds = preds.argmax(dim=1).cpu().numpy() # Shape: [batch_size, H, W] + x = x.cpu().numpy() # Shape: [batch_size, 1, H, W] + y = y.cpu().numpy() # Shape: [batch_size, H, W] + + for i in range(x.shape[0]): + input_img = x[i].squeeze(0) # Shape: [H, W] + input_img = (input_img * 255).astype(np.uint8) # Scale to [0, 255] + + true_mask = y[i] # Shape: [H, W] + pred_mask = preds[i] # Shape: [H, W] + + for cls in range(num_classes): + if len(class_images[cls]) >= max_images_per_class: + continue # Skip if already have enough images for this class + + # Create binary masks for the current class + gt_binary = (true_mask == cls).astype(np.uint8) + pred_binary = (pred_mask == cls).astype(np.uint8) + + # Initialize color masks + gt_color = np.zeros((gt_binary.shape[0], gt_binary.shape[1], 3), dtype=np.uint8) + pred_color = np.zeros_like(gt_color) + + # Assign colors based on binary masks + gt_color[gt_binary == 1] = green + gt_color[gt_binary == 0] = black + + pred_color[pred_binary == 1] = green + pred_color[pred_binary == 0] = black + + # Combine input, ground truth, and predicted masks horizontally + combined = np.hstack((input_img[..., np.newaxis].repeat(3, axis=2), gt_color, pred_color)) + class_images[cls].append(combined) + + if len(class_images[cls]) >= max_images_per_class: + continue # Stop collecting images for this class + + # Now, create stacked images per class group + for row_start in range(0, num_classes, classes_per_row): + classes_in_row = list(range(row_start, min(row_start + classes_per_row, num_classes))) + fig, axs = plt.subplots(1, classes_per_row, figsize=(15, 5)) + + for idx, cls in enumerate(classes_in_row): + if idx >= len(axs): + break # In case num_classes is not a multiple of classes_per_row + + if len(class_images[cls]) == 0: + axs[idx].axis('off') + axs[idx].set_title(f'Class {cls} - No Images') + continue + + # Concatenate images for the class vertically + imgs = class_images[cls] + concatenated = np.vstack(imgs) + axs[idx].imshow(concatenated) + axs[idx].set_title(f'Class {cls}') + axs[idx].axis('off') + + # Hide any unused subplots + for idx in range(len(classes_in_row), classes_per_row): + axs[idx].axis('off') + + plt.tight_layout() + save_path = os.path.join(folder, f"epoch_{epoch}_classes_{row_start}-{row_start + classes_per_row -1}.png") + plt.savefig(save_path) + plt.close(fig) + + print(f"Saved stacked prediction images for epoch {epoch} to '{folder}' folder.") + return + +def validate_fn(loader, model, loss_fn): + """ + Evaluates the model on the validation set. + + Parameters: + loader (DataLoader): Validation data loader. + model (nn.Module): The model to evaluate. + loss_fn (Loss): Loss function. + + Returns: + float: Average validation loss for the epoch. + """ + model.eval() + total_loss = 0.0 + dice_coeffs = [] + with torch.no_grad(): + for batch_idx, (data, targets) in enumerate(loader): + data = data.to(device=device) + targets = targets.float().squeeze(1).to(device=device) + targets = targets.long() + + with torch.cuda.amp.autocast(): + predictions = model(data) + loss = loss_fn(predictions, targets) + + total_loss += loss.item() + # Retrieve Dice coefficients from the loss function + dice_coeff = loss_fn.get_last_dice_coeff() # [num_classes] + dice_coeffs.append(dice_coeff) + + + average_loss = total_loss / len(loader) + average_loss = total_loss / len(loader) + + # Convert list of Dice coefficients to a NumPy array for averaging + dice_coeffs = np.array(dice_coeffs) # Shape: [num_batches, num_classes] + mean_dice_coeff = dice_coeffs.mean(axis=0) # Shape: [num_classes] + + # Create a dictionary for easy interpretation + dice_dict = {cls: mean_dice_coeff[cls] for cls in range(len(mean_dice_coeff))} + return average_loss, dice_dict + + + +if __name__ == '__main__': + main()