diff --git a/recognition/2d_unet_s46974426/README.md b/recognition/2d_unet_s46974426/README.md new file mode 100644 index 0000000000..782bff5975 --- /dev/null +++ b/recognition/2d_unet_s46974426/README.md @@ -0,0 +1,72 @@ +PLEASE NOTE: this version differs from the PDF submission because it could not be re-submitted.
+
+COMP3710 2D UNet Report
+
+Using a 2D UNet to segment the HipMRI Study on Prostate Cancer dataset
+
+The task for this report was to create four files (modules.py, train.py, dataset.py and predict.py) that load the HipMRI study data and segment it with a 2D UNet implemented in PyTorch. The task description taken directly from Blackboard is below.
+
+Task Description from Blackboard: "Segment the HipMRI Study on Prostate Cancer (see Appendix for link) using the processed 2D slices (2D images) available here with the 2D UNet [1] with all labels having a minimum Dice similarity coefficient of 0.75 on the test set on the prostate label. You will need to load Nifti file format and sample code is provided in Appendix B. [Easy Difficulty]"
+
+I quickly want to mention that I prefixed each commit with 'topic recognition'. This was force of habit: when working on git repositories in the past, I would first create a branch named after a change request (e.g. "CR-123") and prefix each commit with the name of that CR.
+
+An initial test script was run to visualise one of the slices before applying the 2D UNet, to get a sense of what the images look like. The resulting image from running test.py can be seen in slice_print_from_initial_test in the images folder.
+
+The data loader was run in a simple for loop to check that it worked; it got roughly 50% of the way through before erroring due to an image sizing issue. To resolve this, an image resizing function was added and called from the data loader. The completed data loader test output can be seen in data_loader_test.png in the images folder.
+
+After spending time fixing errors in my original versions of the modules, dataset, predict and train scripts, I eventually gave up on them as they would not run.
+
+I went online, found a similar example of a 2D UNet implemented in PyTorch and adapted the code to suit my problem; the reference to this repository is below.
+
+Author: milesial
+Date: 11/02/2024
+Title of program/source code: U-Net: Semantic segmentation with PyTorch
+Code version: 475
+Type (e.g. computer program, source code): computer program
+Web address or publisher (e.g. program publisher, URL): https://github.com/milesial/Pytorch-UNet
+
+Also, during this process, I discovered that the masks were in the 'seg' datasets and the images were in the datasets not suffixed with 'seg' (I had it the wrong way around originally).
+
+After attempting to run train.py and fixing errors as they occurred, I was eventually able to run the training code in full and generate loss and Dice-coefficient-based validation plots.
+
+I first ran the training code for 5 epochs; a graph showing the batch loss and a graph showing the Dice score can both be seen in the images folder. I then ran it for 50 epochs and the corresponding graphs are also in the images folder. The console progress can also be seen in the console_running image in the images folder.
+
+This final part outlines the working principles of the algorithm and the problem it solves. The PyTorch UNet is comprised of four parts: an encoder, a decoder, a bottleneck and a final convolutional output layer. The modules script contains the UNet's definition. It also includes the Dice coefficient handling used to calculate the Dice loss, which measures the overlap between the predicted and true masks in order to quantify the segmentation model's accuracy.
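+To make the overlap measure concrete, below is a minimal sketch of the Dice calculation (a simplified, single-mask version of the dice_coeff and dice_loss logic in modules.py; the tensor values here are illustrative only):
+
+```python
+import torch
+
+def dice_example(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    # Twice the intersection of foreground pixels divided by the total foreground
+    # in both masks: 1.0 means perfect overlap, 0.0 means no overlap at all.
+    inter = 2 * (pred * target).sum()
+    total = pred.sum() + target.sum()
+    return (inter + eps) / (total + eps)
+
+pred = torch.tensor([[0., 1.], [1., 1.]])
+target = torch.tensor([[0., 1.], [0., 1.]])
+loss = 1 - dice_example(pred, target)  # Dice = 4 / 5 = 0.8, so Dice loss = 0.2
+```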
+I also added a CombinedDataset class to combine two datasets (the segmentation masks and the images), because datasets that include both images and masks are typically used with UNet algorithms. The modules script also includes some basic dataset classes, a method to load images, a check for the unique values in the masks, and some basic plotting logic.
+
+The train script initialises the UNet model defined in modules.py and then trains it on the segmentation dataset. Before this is done, however, the data is loaded and transformed into 2D arrays using the load_data_2D function provided in the task appendix. The train script defines the main training loop, iterating over the data in batches and calculating losses and Dice scores, which are plotted after training has completed. It also saves checkpoints as the training loop completes each epoch, which is made up of a number of batches (typically 5-6 in this case).
+
+The dataset script just contains the load_data_2D function as given in the appendix of the task sheet. It also contains an image resizing function to make the image dimensions consistent.
+
+Finally, the predict script's purpose is to generate mask predictions for new images using a trained and saved UNet model.
+
 diff --git a/recognition/2d_unet_s46974426/dataset.py b/recognition/2d_unet_s46974426/dataset.py new file mode 100644 index 0000000000..7e80b7ae62 --- /dev/null +++ b/recognition/2d_unet_s46974426/dataset.py @@ -0,0 +1,89 @@ +import numpy as np
+import nibabel as nib
+from tqdm import tqdm
+import cv2
+import os  # For working with file paths
+import torch
+from torch.utils.data import Dataset
+
+
+'''
+    Resizes the images so they are all consistent (this is required for the data loading process in load_data_2D)
+
+    Parameters:
+    - image: image that is being resized
+    - target_shape: shape the image is being resized to, e.g. 256x128 pixels
+'''
+def resize_image(image, target_shape):
+    """Resize image to the target shape using OpenCV."""
+    return cv2.resize(image, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR)
+
+def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
+    channels = np.unique(arr)
+    res = np.zeros(arr.shape + (len(channels),), dtype=dtype)
+    for c in channels:
+        c = int(c)
+        res[..., c:c + 1][arr == c] = 1
+    return res
+
+def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False):
+    """
+    Load medical image data from a list of file names into a single numpy array (one entry per case).
+ + Parameters: + - imageNames: List of image file names + - normImage: bool (normalize the image 0.0 - 1.0) + - categorical: bool (indicates if the data is categorical) + - dtype: Desired data type (default: np.float32) + - getAffines: bool (return affine matrices along with images) + - early_stop: bool (stop loading prematurely for testing purposes) + + Returns: + - images: Loaded image data as a numpy array + - affines: List of affine matrices (if getAffines is True) + """ + affines = [] + num = len(imageNames) + first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged') + + if len(first_case.shape) == 3: + first_case = first_case[:, :, 0] # Remove extra dims if necessary + if categorical: + first_case = to_channels(first_case, dtype=dtype) + rows, cols, channels = first_case.shape + images = np.zeros((num, rows, cols, channels), dtype=dtype) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype=dtype) + + for i, inName in enumerate(tqdm(imageNames)): + niftiImage = nib.load(inName) + inImage = niftiImage.get_fdata(caching='unchanged') + affine = niftiImage.affine + + if len(inImage.shape) == 3: + inImage = inImage[:, :, 0] # Remove extra dims if necessary + inImage = inImage.astype(dtype) + + # Resize the image if necessary + if inImage.shape != (rows, cols): + inImage = resize_image(inImage, (rows, cols)) + + if normImage: + inImage = (inImage - inImage.mean()) / inImage.std() + + if categorical: + inImage = to_channels(inImage, dtype=dtype) + images[i, :, :, :] = inImage + else: + images[i, :, :] = inImage + + affines.append(affine) + if i > 20 and early_stop: + break + + if getAffines: + return images, affines + else: + return images + diff --git a/recognition/2d_unet_s46974426/images/50Epoc_validation_score.png b/recognition/2d_unet_s46974426/images/50Epoc_validation_score.png new file mode 100644 index 0000000000..a30b776890 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/50Epoc_validation_score.png differ diff --git a/recognition/2d_unet_s46974426/images/50Epoch_batch_loss.png b/recognition/2d_unet_s46974426/images/50Epoch_batch_loss.png new file mode 100644 index 0000000000..70f7824c1f Binary files /dev/null and b/recognition/2d_unet_s46974426/images/50Epoch_batch_loss.png differ diff --git a/recognition/2d_unet_s46974426/images/5Epochs_batch_loss.png b/recognition/2d_unet_s46974426/images/5Epochs_batch_loss.png new file mode 100644 index 0000000000..79d949eba5 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/5Epochs_batch_loss.png differ diff --git a/recognition/2d_unet_s46974426/images/5Epochs_dice_score.png b/recognition/2d_unet_s46974426/images/5Epochs_dice_score.png new file mode 100644 index 0000000000..b773765285 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/5Epochs_dice_score.png differ diff --git a/recognition/2d_unet_s46974426/images/console_running.png b/recognition/2d_unet_s46974426/images/console_running.png new file mode 100644 index 0000000000..cfcce15b94 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/console_running.png differ diff --git a/recognition/2d_unet_s46974426/images/data_loader_test.png b/recognition/2d_unet_s46974426/images/data_loader_test.png new file mode 100644 index 0000000000..46be4f5af6 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/data_loader_test.png differ diff --git a/recognition/2d_unet_s46974426/images/slice_print_from_initial_test.png b/recognition/2d_unet_s46974426/images/slice_print_from_initial_test.png 
new file mode 100644 index 0000000000..b01fac9b74 Binary files /dev/null and b/recognition/2d_unet_s46974426/images/slice_print_from_initial_test.png differ diff --git a/recognition/2d_unet_s46974426/modules.py b/recognition/2d_unet_s46974426/modules.py new file mode 100644 index 0000000000..16a3638aca --- /dev/null +++ b/recognition/2d_unet_s46974426/modules.py @@ -0,0 +1,372 @@ +"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py""" + +""" Parts of the U-Net model """ + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.utils.data import Dataset +import numpy as np +from pathlib import Path +from PIL import Image +from os.path import splitext, isfile, join +from os import listdir +from tqdm import tqdm +from functools import partial +from multiprocessing import Pool +import matplotlib.pyplot as plt + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +''' + Definition of a OutConv class +''' +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + +''' + Definition of a Pytorch UNet to be trained on the loaded 2d segmentation dataset +''' +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=False): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = (DoubleConv(n_channels, 64)) + self.down1 = (Down(64, 128)) + self.down2 = (Down(128, 256)) + 
self.down3 = (Down(256, 512)) + factor = 2 if bilinear else 1 + self.down4 = (Down(512, 1024 // factor)) + self.up1 = (Up(1024, 512 // factor, bilinear)) + self.up2 = (Up(512, 256 // factor, bilinear)) + self.up3 = (Up(256, 128 // factor, bilinear)) + self.up4 = (Up(128, 64, bilinear)) + self.outc = (OutConv(64, n_classes)) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits + + def use_checkpointing(self): + self.inc = torch.utils.checkpoint(self.inc) + self.down1 = torch.utils.checkpoint(self.down1) + self.down2 = torch.utils.checkpoint(self.down2) + self.down3 = torch.utils.checkpoint(self.down3) + self.down4 = torch.utils.checkpoint(self.down4) + self.up1 = torch.utils.checkpoint(self.up1) + self.up2 = torch.utils.checkpoint(self.up2) + self.up3 = torch.utils.checkpoint(self.up3) + self.up4 = torch.utils.checkpoint(self.up4) + self.outc = torch.utils.checkpoint(self.outc) + +''' + Handles calculating dice coefficent between a predicted mask and a true mask + + Parameters: + -input: predicted mask (as a tensor) + -target: true mask (as a tensor) + -reduce_batch_first: wheater to average over batch dimension + -epsilon: small value to prevent division by 0 +''' +def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all batches, or for a single mask + assert input.size() == target.size() + assert input.dim() == 3 or not reduce_batch_first + + sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) + + inter = 2 * (input * target).sum(dim=sum_dim) + sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) + sets_sum = torch.where(sets_sum == 0, inter, sets_sum) + + dice = (inter + epsilon) / (sets_sum + epsilon) + return dice.mean() + +''' + Calculates dice coefficient across multiple classes + + Parameters: + -input: predicted mask (as a tensor) + -target: true mask (as a tensor) + -reduce_batch_first: wheater to average over batch dimension + -epsilon: small value to prevent division by 0 +''' +def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all classes + return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon) + +''' + Calculates dice loss for an image segmentation + + Parameters: + -input: predicted mask (as a tensor) + -target: true mask (as a tensor) + -multiclass: is multiclass? +''' +def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): + # Dice loss (objective to minimize) between 0 and 1 + fn = multiclass_dice_coeff if multiclass else dice_coeff + return 1 - fn(input, target, reduce_batch_first=True) + +''' + Combines two datasets, in this case used to combine the mask and segmetation datasets + + Initialisation: + -takes two datasets to intialise the class and returns them in a tuple +''' +class CombinedDataset(Dataset): + def __init__(self, images, image_masks, transform=None): + """ + Args: + images (numpy array): The array containing image data. + image_masks (numpy array): The array containing corresponding mask data. + transform (callable, optional): Optional transform to be applied on a sample. 
+ """ + self.images = images + self.image_masks = image_masks + self.transform = transform + + def __len__(self): + # Return the number of samples in the dataset (should be the same for images and masks) + return len(self.images) + + def __getitem__(self, idx): + # Load the image and the corresponding mask + image = self.images[idx] + mask = self.image_masks[idx] + + # Add a channel dimension if the images are grayscale (i.e., single-channel) + if len(image.shape) == 2: # If the image is HxW (no channel dimension) + image = np.expand_dims(image, axis=0) # Add channel dimension -> (1, H, W) + + if len(mask.shape) == 2: # If the mask is HxW (no channel dimension) + mask = np.expand_dims(mask, axis=0) # Add channel dimension -> (1, H, W) + + # Apply any transformation if provided + if self.transform: + image = self.transform(image) + mask = self.transform(mask) + + # Convert to PyTorch tensors + image = torch.from_numpy(image).float() + mask = torch.from_numpy(mask).float() + + return image, mask + +''' + Loads a specified image + + Parameters: + -filename: path to the image being loaded +''' +def load_image(filename): + ext = splitext(filename)[1] + if ext == '.npy': + return Image.fromarray(np.load(filename)) + elif ext in ['.pt', '.pth']: + return Image.fromarray(torch.load(filename).numpy()) + else: + return Image.open(filename) + +''' + Identifies unique values in a mask + + Parameters: + -idx: string identifier of a mask + -mask_dir: directory contianing the masks + -mask_suffix: identifies the mask by combining this and the idx +''' +def unique_mask_values(idx, mask_dir, mask_suffix): + mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] + mask = np.asarray(load_image(mask_file)) + if mask.ndim == 2: + return np.unique(mask) + elif mask.ndim == 3: + mask = mask.reshape(-1, mask.shape[-1]) + return np.unique(mask, axis=0) + else: + raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') + +''' + Defines a BasicDatabase class +''' +class BasicDataset(Dataset): + def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''): + self.images_dir = Path(images_dir) + self.mask_dir = Path(mask_dir) + assert 0 < scale <= 1, 'Scale must be between 0 and 1' + self.scale = scale + self.mask_suffix = mask_suffix + + self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] + if not self.ids: + raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') + + logging.info(f'Creating dataset with {len(self.ids)} examples') + logging.info('Scanning mask files to determine unique values') + with Pool() as p: + unique = list(tqdm( + p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), + total=len(self.ids) + )) + + self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) + logging.info(f'Unique mask values: {self.mask_values}') + + def __len__(self): + return len(self.ids) + + @staticmethod + def preprocess(mask_values, pil_img, scale, is_mask): + w, h = pil_img.size + newW, newH = int(scale * w), int(scale * h) + assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' + pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) + img = np.asarray(pil_img) + + if is_mask: + mask = np.zeros((newH, newW), dtype=np.int64) + for i, v in enumerate(mask_values): + if img.ndim == 2: + mask[img == v] = i + 
else: + mask[(img == v).all(-1)] = i + + return mask + + else: + if img.ndim == 2: + img = img[np.newaxis, ...] + else: + img = img.transpose((2, 0, 1)) + + if (img > 1).any(): + img = img / 255.0 + + return img + + def __getitem__(self, idx): + name = self.ids[idx] + mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) + img_file = list(self.images_dir.glob(name + '.*')) + + assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' + assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' + mask = load_image(mask_file[0]) + img = load_image(img_file[0]) + + assert img.size == mask.size, \ + f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' + + img = self.preprocess(self.mask_values, img, self.scale, is_mask=False) + mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True) + + return { + 'image': torch.as_tensor(img.copy()).float().contiguous(), + 'mask': torch.as_tensor(mask.copy()).long().contiguous() + } + + +''' + Defines a CarvanaDataset class +''' +class CarvanaDataset(BasicDataset): + def __init__(self, images_dir, mask_dir, scale=1): + super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') + +''' + Funtion to plot image and its corrolated mask + + Paramters: + -img: image + -mask: mask +''' +def plot_img_and_mask(img, mask): + classes = mask.max() + 1 + fig, ax = plt.subplots(1, classes + 1) + ax[0].set_title('Input image') + ax[0].imshow(img) + for i in range(classes): + ax[i + 1].set_title(f'Mask (class {i + 1})') + ax[i + 1].imshow(mask == i) + plt.xticks([]), plt.yticks([]) + plt.show() \ No newline at end of file diff --git a/recognition/2d_unet_s46974426/predict.py b/recognition/2d_unet_s46974426/predict.py new file mode 100644 index 0000000000..8255e66794 --- /dev/null +++ b/recognition/2d_unet_s46974426/predict.py @@ -0,0 +1,118 @@ +import argparse +import logging +import os + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms + +from modules import BasicDataset, UNet, plot_img_and_mask + +''' + This script is unaltered from the example UNet usage (referenced in the report) as I was not super sure what to do with it and ran out of time to implement for this segmentation example. 
+''' +def predict_img(net, + full_img, + device, + scale_factor=1, + out_threshold=0.5): + net.eval() + img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False)) + img = img.unsqueeze(0) + img = img.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + output = net(img).cpu() + output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear') + if net.n_classes > 1: + mask = output.argmax(dim=1) + else: + mask = torch.sigmoid(output) > out_threshold + + return mask[0].long().squeeze().numpy() + + +def get_args(): + parser = argparse.ArgumentParser(description='Predict masks from input images') + parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', + help='Specify the file in which the model is stored') + parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True) + parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images') + parser.add_argument('--viz', '-v', action='store_true', + help='Visualize the images as they are processed') + parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks') + parser.add_argument('--mask-threshold', '-t', type=float, default=0.5, + help='Minimum probability value to consider a mask pixel white') + parser.add_argument('--scale', '-s', type=float, default=0.5, + help='Scale factor for the input images') + parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') + parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') + + return parser.parse_args() + + +def get_output_filenames(args): + def _generate_name(fn): + return f'{os.path.splitext(fn)[0]}_OUT.png' + + return args.output or list(map(_generate_name, args.input)) + + +def mask_to_image(mask: np.ndarray, mask_values): + if isinstance(mask_values[0], list): + out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8) + elif mask_values == [0, 1]: + out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool) + else: + out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8) + + if mask.ndim == 3: + mask = np.argmax(mask, axis=0) + + for i, v in enumerate(mask_values): + out[mask == i] = v + + return Image.fromarray(out) + + +if __name__ == '__main__': + args = get_args() + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + + in_files = args.input + out_files = get_output_filenames(args) + + net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logging.info(f'Loading model {args.model}') + logging.info(f'Using device {device}') + + net.to(device=device) + state_dict = torch.load(args.model, map_location=device) + mask_values = state_dict.pop('mask_values', [0, 1]) + net.load_state_dict(state_dict) + + logging.info('Model loaded!') + + for i, filename in enumerate(in_files): + logging.info(f'Predicting image {filename} ...') + img = Image.open(filename) + + mask = predict_img(net=net, + full_img=img, + scale_factor=args.scale, + out_threshold=args.mask_threshold, + device=device) + + if not args.no_save: + out_filename = out_files[i] + result = mask_to_image(mask, mask_values) + result.save(out_filename) + logging.info(f'Mask saved to {out_filename}') + + if args.viz: + logging.info(f'Visualizing results for image {filename}, close to continue...') + 
plot_img_and_mask(img, mask) \ No newline at end of file diff --git a/recognition/2d_unet_s46974426/train.py b/recognition/2d_unet_s46974426/train.py new file mode 100644 index 0000000000..e06782defe --- /dev/null +++ b/recognition/2d_unet_s46974426/train.py @@ -0,0 +1,372 @@ +import argparse +import logging +import os +import random +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +from pathlib import Path +from torch import optim +from torch.utils.data import DataLoader, random_split +from tqdm import tqdm +from modules import dice_coeff, multiclass_dice_coeff, UNet, CombinedDataset, dice_loss +from dataset import load_data_2D +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix +import seaborn as sns + +import wandb + +#setting dataset paths and empty arrays for plotting +dir_img = Path('C:/Users/rober/Desktop/COMP3710/keras_slices_test') +dir_mask = Path('C:/Users/rober/Desktop/COMP3710/keras_slices_seg_test') +dir_img_val = Path('C:/Users/rober/Desktop/COMP3710/keras_slices_validate') +dir_mask_val = Path('C:/Users/rober/Desktop/COMP3710/keras_slices_seg_validate') +dir_checkpoint = Path('./checkpoints') + +batch_losses = [] +val_dice_scores = [] +conf_matrix_total = None + +''' + Computes the dice score of the UNet, a measure of how well the model segments images by comparing the predicted and true masks + + Parameters: + -net: the UNet + -dataloader: dataloader that provides validation data in batches + -device: the device where the model is run on (in this case cuda) + -amp: depricated from reference code +''' +def evaluate(net, dataloader, device, amp): + net.eval() + num_val_batches = len(dataloader) + dice_score = 0 + + # iterate over the validation set + with torch.autocast(device_type = 'cuda'): + for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False): + image, mask_true = batch + image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) + mask_true = mask_true.to(device=device, dtype=torch.long).squeeze(1) + mask_true = torch.clamp(mask_true, min=0, max=1) + + mask_pred = net(image) + + if net.n_classes == 1: + mask_pred = (F.sigmoid(mask_pred) > 0.5).float() + dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False) + + # Confusion matrix for binary classification + # preds_flat = mask_pred.view(-1).cpu().numpy() + # labels_flat = mask_true.view(-1).cpu().numpy() + # conf_matrix = confusion_matrix(labels_flat, preds_flat) + + else: + mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float() + mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() + dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False) + + # Confusion matrix for multi-class classification + # preds_flat = mask_pred.argmax(dim=1).view(-1).cpu().numpy() + # labels_flat = mask_true.argmax(dim=1).view(-1).cpu().numpy() + # conf_matrix = confusion_matrix(labels_flat, preds_flat) + + # if conf_matrix_total is None: + # conf_matrix_total = conf_matrix + # else: + # conf_matrix_total += conf_matrix + + net.train() + + return dice_score / max(num_val_batches, 1) + +''' + Main training model for training the segmentation model (UNet) + + Parameters: + -model: the UNet model + -device: device the model is run on (cuda) + -epochs: the number of epochs run + 
-batch_size: the size of the batch + -learning_rate: the learning rate of the model + -save_checkpoint: boolean to determine if the model is saves based on progress + -img_scale: scaling factor for images (default seems to be typically 0,5) + -amp: boolean to determine use of automatic mixed precision (amp) + -weight_decay: regularisation parameter to prevent overfitting + -momentum: hyperparam for optimiser to accelerate SGD + -gradient_clipping: clips back progogation to prevent exploding gradients +''' +def train_model( + model, + device, + epochs: int = 50, + batch_size: int = 1, + learning_rate: float = 1e-5, + val_percent: float = 0.1, + save_checkpoint: bool = True, + img_scale: float = 0.5, + amp: bool = False, + weight_decay: float = 1e-8, + momentum: float = 0.999, + gradient_clipping: float = 1.0, +): + image_files = [os.path.join(dir_img, f) for f in os.listdir(dir_img) if f.endswith('.nii.gz') or f.endswith('.nii')] + # Step 2: Load the images using load_data_2D + images = load_data_2D(image_files, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False) + + image_files_mask = [os.path.join(dir_mask, f) for f in os.listdir(dir_mask) if f.endswith('.nii.gz') or f.endswith('.nii')] + # Step 2: Load the images using load_data_2D + images_mask = load_data_2D(image_files_mask, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False) + + image_files_val = [os.path.join(dir_img_val, f) for f in os.listdir(dir_img_val) if f.endswith('.nii.gz') or f.endswith('.nii')] + # Step 2: Load the images using load_data_2D + images_val = load_data_2D(image_files_val, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False) + + image_files_mask_val = [os.path.join(dir_mask_val, f) for f in os.listdir(dir_mask_val) if f.endswith('.nii.gz') or f.endswith('.nii')] + # Step 2: Load the images using load_data_2D + images_mask_val = load_data_2D(image_files_mask_val, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False) + + + training_set = CombinedDataset(images, images_mask) + validate_set = CombinedDataset(images_val, images_mask_val) + # 2. Split into train / validation partitions + n_val = int(len(images) * val_percent) + n_train = len(images) - n_val + train_set, val_set = training_set, validate_set + print(len(train_set)) + + train_loader = DataLoader(train_set, shuffle=True) + val_loader = DataLoader(val_set, shuffle=False, drop_last=True) + + # (Initialize logging) + # experiment = wandb.init(project='U-Net', resume='allow', anonymous='must') + # experiment.config.update( + # dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate, + # val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp) + # ) + + logging.info(f'''Starting training: + Epochs: {epochs} + Batch size: {batch_size} + Learning rate: {learning_rate} + Training size: {n_train} + Validation size: {n_val} + Checkpoints: {save_checkpoint} + Device: {device.type} + Images scaling: {img_scale} + Mixed Precision: {amp} + ''') + + # 4. 
Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP + optimizer = optim.RMSprop(model.parameters(), + lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5) # goal: maximize Dice score + grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) + criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss() + global_step = 0 + + # 5. Begin training + for epoch in range(1, epochs + 1): + model.train() + epoch_loss = 0 + with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: + for batch in train_loader: + + images, true_masks = batch + #print("masks ", true_masks.size()) + true_masks = true_masks.squeeze(1) + true_masks = torch.clamp(true_masks, min=0, max=1) + + #print("hello: ", images.size()) + + assert images.shape[1] == model.n_channels, \ + f'Network has been defined with {model.n_channels} input channels, ' \ + f'but loaded images have {images.shape[1]} channels. Please check that ' \ + 'the images are loaded correctly.' + + images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) + true_masks = true_masks.to(device=device, dtype=torch.long) + + with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): + masks_pred = model(images) + if model.n_classes == 1: + loss = criterion(masks_pred.squeeze(1), true_masks.float()) + loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False) + else: + loss = criterion(masks_pred, true_masks) + loss += dice_loss( + F.softmax(masks_pred, dim=1).float(), + F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), + multiclass=True + ) + + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss).backward() + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) + grad_scaler.step(optimizer) + grad_scaler.update() + + pbar.update(images.shape[0]) + global_step += 1 + epoch_loss += loss.item() + # experiment.log({ + # 'train loss': loss.item(), + # 'step': global_step, + # 'epoch': epoch + # }) + pbar.set_postfix(**{'loss (batch)': loss.item()}) + batch_losses.append(loss.item()) + # Evaluation round + division_step = (n_train // (5 * batch_size)) + if division_step > 0: + if global_step % division_step == 0: + histograms = {} + for tag, value in model.named_parameters(): + tag = tag.replace('/', '.') + if not (torch.isinf(value) | torch.isnan(value)).any(): + histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) + if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any(): + histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) + + val_score = evaluate(model, val_loader, device, amp) + scheduler.step(val_score) + + val_dice_scores.append(val_score) + logging.info('Validation Dice score: {}'.format(val_score)) + try: + pass + # experiment.log({ + # 'learning rate': optimizer.param_groups[0]['lr'], + # 'validation Dice': val_score, + # 'images': wandb.Image(images[0].cpu()), + # 'masks': { + # 'true': wandb.Image(true_masks[0].float().cpu()), + # 'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()), + # }, + # 'step': global_step, + # 'epoch': epoch, + # **histograms + # }) + except: + pass + + if save_checkpoint: + Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) + state_dict = model.state_dict() + state_dict['mask_values'] = training_set.image_masks + 
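+                # Note: 'mask_values' is stored inside the same state_dict as the model weights;
+                # predict.py pops this key back out of the loaded checkpoint before calling load_state_dict.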
torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch))) + logging.info(f'Checkpoint {epoch} saved!') + + plt.figure(figsize=(10, 5)) + plt.plot(batch_losses, label='Batch Loss') + plt.title('Batch Loss During Training') + plt.xlabel('Batch') + plt.ylabel('Loss') + plt.legend() + plt.show() + + val_dice_scores_cpu = [score.cpu().item() for score in val_dice_scores] + # Plot validation Dice scores + plt.figure(figsize=(10, 5)) + plt.plot(val_dice_scores_cpu, label='Validation Dice Score') + plt.title('Validation Dice Score') + plt.xlabel('Epoch') + plt.ylabel('Dice Score') + plt.legend() + plt.show() + + # # Confusion Matrix Plot + # plt.figure(figsize=(8, 6)) + # sns.heatmap(conf_matrix_total, annot=True, fmt='d', cmap='Blues') + # plt.title('Confusion Matrix') + # plt.xlabel('Predicted') + # plt.ylabel('True') + # plt.show() + +''' + Command-line interface to train the UNet model on images and corresponding target masks + + Parameters: + --epochs: number of epochs run + --batch-size: size of each batch + --learning-rate: learning rate of model + --load: specifies path to pre-trained model + --scale: downscaling factor + --validation: dataset percentage used for validation + --amp: whether to use automatic mixed precision (amp) + --bilinear: whether to use bilinear upsampling + --classes: number of classes for the segmentation task +''' +def get_args(): + parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') + parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs') + parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size') + parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5, + help='Learning rate', dest='lr') + parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file') + parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images') + parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0, + help='Percent of the data that is used as validation (0-100)') + parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision') + parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upksampling') + parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') + + return parser.parse_args() + +args = get_args() +logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +logging.info(f'Using device {device}') + +# Change here to adapt to your data +# n_channels=3 for RGB images +# n_classes is the number of probabilities you want to get per pixel +model = UNet(n_channels=1, n_classes=2, bilinear=args.bilinear) +print("hi: ", args.classes) +model = model.to(memory_format=torch.channels_last) + +logging.info(f'Network:\n' + f'\t{model.n_channels} input channels\n' + f'\t{model.n_classes} output channels (classes)\n' + f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling') + +# if args.load: +# state_dict = torch.load(args.load, map_location=device) +# del state_dict['mask_values'] +# model.load_state_dict(state_dict) +# logging.info(f'Model loaded from {args.load}') + +model.to(device=device) +try: + train_model( + model=model, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.lr, + device=device, 
+ img_scale=args.scale, + val_percent=args.val / 100, + amp=args.amp + ) +except torch.cuda.OutOfMemoryError: + logging.error('Detected OutOfMemoryError! ' + 'Enabling checkpointing to reduce memory usage, but this slows down training. ' + 'Consider enabling AMP (--amp) for fast and memory efficient training') + torch.cuda.empty_cache() + model.use_checkpointing() + train_model( + model=model, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.lr, + device=device, + img_scale=args.scale, + val_percent=args.val / 100, + amp=args.amp + ) +