Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2D-UNet Image Segmentation of HipMRI slices #157

Open
wants to merge 18 commits into
base: topic-recognition
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9bd2328
Added functions to read and load nifti files
LiftedFordRanger Oct 11, 2024
0ac7432
Added loader for train, test and validate for segmentation. Scales
LiftedFordRanger Oct 12, 2024
3001366
Fixed one-hot encoding of labels and scaling of input images
LiftedFordRanger Oct 12, 2024
14b12a8
Fixed one_hot encoding of labels
LiftedFordRanger Oct 13, 2024
99ca0f6
Added modules.py with Unet model, moved files into folder
LiftedFordRanger Oct 15, 2024
4bf846c
Added train.py - model learning loop
LiftedFordRanger Oct 22, 2024
7d41827
Fixed number of filters in segmentation layers of model to reflect nu…
LiftedFordRanger Oct 23, 2024
4a0ed4f
Added data normalisation, fixed scaling of labels when resizing input
LiftedFordRanger Oct 23, 2024
6b9bcf7
Moved DataGenerator to dataset.py
LiftedFordRanger Oct 24, 2024
d692c94
Added predict.py for model evaluation. Moved some data loading from
LiftedFordRanger Oct 26, 2024
45b4f28
Added comments and removed commented out code and unnecessary prints
LiftedFordRanger Oct 27, 2024
3e9b851
Added README
LiftedFordRanger Oct 27, 2024
04d5181
Added folder with graphs of prediction and loss over iterations
LiftedFordRanger Oct 27, 2024
43d5183
Fixed images in README
LiftedFordRanger Oct 27, 2024
fc8408d
Improved placement of images in README
LiftedFordRanger Oct 27, 2024
c5b3ac7
Fixed title image in README
LiftedFordRanger Oct 27, 2024
038895e
Added UNet image to graphs folder for README
LiftedFordRanger Oct 27, 2024
621efb8
Added Header blocks to code files
LiftedFordRanger Oct 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions recognition/2D_UNet_46991638/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Segmenting Hip-MRI to identify Prostate Cancer using 2D-UNet
Every year tens of thousands of men are diagnosed with prostate cancer, mostly effecting elderly men[1]. The identification of abnormal or enlarged prostates in a HipMRI could therefore be useful for assisting doctors in early detection.

## 2D UNet
![alt text](graphs/UNet.png)[2]

The UNet is a model mostly used for segmenting images and is especially effective on small datasets, which are common in medical imaging[3]. The model consists of an encoding network which extracts features from the input followed by a decoding network which creates the segmentation mask[2]. A key difference between the UNet earlier image segmentation models is that the UNet utilises skip connections which allow for finer details in the original image to be carried through the model into the segmentation mask.

## Implementation
This model is design to be used on 2D slices of HipMRI data which can be found [here](https://data.csiro.au/collection/csiro:51392v2). The dataset was preprocesses into 2D-Slices. Some images were of different sizes so each image was normalised to a size of 256x256 and the pixel values normalised 0-1. The preprocessing already separated the data into train, validate and test sets which were of size 11640, 660 and 540 respectively. Using this distribution aims to maximise the available data for training while still providing enough data for testing and validation. The model itself was adapted from [this](https://github.com/shakes76/PatternFlow/tree/master/recognition/MySolution) source, modifying the output layers and filter size for this problem.

## Results

After some tuning, a preliminary test run of 50 epochs shows only a slight increase in accuracy for the validation data while the training data improves significantly. This disparity could be a result of overfitting, although due to the variance in the validation data it also seemed likely that the learning rate was too high.

![alt text](graphs/Dice_Coeff1.png)

In hopes to improve this the learning rate was lowered from $10^{-3}$ to $10^{-4}$ giving the following result.

![alt text](graphs/Dice_Coeff.png)

As expected the training data is learned slower, however there is not a clear difference in results for the validation data. In both instances the validation dice score roughly oscillates between 0.7 and 0.725.

The higher learning rate gave a slightly better evaluation on the test set with a dice score of 0.7507 compared to 0.7462. Which is not significant enough to suggest the learning rate is a meaningful factor in the models lack of accuracy. Regardless the first (higher) learning rate is used for model evaluation. Other hyperparameters were tested with other values such as filter size, batch size and number of epochs. However they gave similarly ineffective results with the validation data stagnating quickly.

Running inference on this iteration of the model gives these segmentation masks for each class


<img src=graphs/Prediction_0.png width="500"> <img src=graphs/True_0.png width="500">
<img src=graphs/Prediction_1.png width="500"> <img src=graphs/True_1.png width="500">
<img src=graphs/Prediction_2.png width="500"> <img src=graphs/True_2.png width="500">
<img src=graphs/Prediction_3.png width="500"> <img src=graphs/True_3.png width="500">
<img src=graphs/Prediction_4.png width="500"> <img src=graphs/True_4.png width="500">
<img src=graphs/Prediction_5.png width="500"> <img src=graphs/True_5.png width="500">

Clearly this result leaves much to be desired and while the model may not be very useful in its current state it provides a starting point for further study.


## Dependencies and Reproducibility
The model is implemented using tensorflow 2.17.0 utilising the nibabel library to read nifti images. The parameters in the provided model are the same that were used to retrieve the displayed results. Which should allow the results to be reproduced or built upon if desired.

## References
[1] https://www.cancer.org.au/cancer-information/types-of-cancer/prostate-cancer

[2] https://medium.com/analytics-vidhya/what-is-unet-157314c87634

[3] https://github.com/NITR098/Awesome-U-Net
Code for UNet modified from https://github.com/shakes76/PatternFlow/blob/master/recognition/MySolution/Methods.ipynb
137 changes: 137 additions & 0 deletions recognition/2D_UNet_46991638/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
This file contains functions which are used for the loading
and preprocessing of the dataset
"""
import numpy as np
import nibabel as nib
from tqdm import tqdm
import glob
from matplotlib import pyplot
from matplotlib import image
import tensorflow as tf
import skimage.transform as sk

def to_channels(arr:np.ndarray, dtype = np.uint8) -> np.ndarray:
channels = np.unique(arr)
res = np.zeros(arr.shape + (6,), dtype = dtype)
for c in channels:
c = int(c)
res [... , c : c +1][arr == c] = 1
return res
# load medical image functions
def load_data_2D (imageNames, normImage = False, categorical = False, dtype = np.float32,
getAffines = False, early_stop = False):
'''
Load medical image data from names , cases list provided into a list for each .

This function pre - allocates 4 D arrays for conv2d to avoid excessive memory usage .

normImage : bool(normalise the image 0.0 - 1.0)
early_stop : Stop loading pre-maturely , leaves arrays mostly empty , for quick loading and testing scripts .
'''
affines = []
# get fixed size
num = len(imageNames)
first_case = nib.load(imageNames[0]).get_fdata(caching="unchanged")
#rescale image
if categorical:
#neareat-neigboor interpolation
first_case = sk.resize(first_case, (256,256), order=0, anti_aliasing=False, preserve_range=True)
else:
#bi-linear
first_case = sk.resize(first_case, (256, 256), order=1, anti_aliasing=True, preserve_range=True)
if len(first_case.shape) == 3:
first_case = first_case [: ,: ,0] # sometimes extra dims , remove
if categorical:
first_case = to_channels(first_case, dtype = dtype)
rows, cols, channels = first_case.shape
images = np.zeros((num, rows, cols, channels), dtype = dtype)
else:
rows, cols = first_case.shape
images = np.zeros((num, rows, cols), dtype = dtype)
for i, inName in enumerate (tqdm(imageNames)):
niftiImage = nib.load(inName)
inImage = niftiImage.get_fdata(caching ="unchanged") # read disk only

if categorical:
#neareat-neigboor interpolation
inImage = sk.resize(inImage, (256,256), order=0, anti_aliasing=False, preserve_range=True)
else:
#bi-linear
inImage = sk.resize(inImage, (256, 256), order=1, anti_aliasing=True, preserve_range=True)

affine = niftiImage.affine
if len (inImage.shape ) == 3:
inImage = inImage[: ,: ,0] # sometimes extra dims in HipMRI_study data
inImage = inImage.astype(dtype)
if normImage:
inImage = (inImage - inImage.mean()) / inImage.std()
if categorical:
inImage = to_channels(inImage, dtype = dtype)
images[i ,: ,: ,:] = inImage
else:
images[i ,: ,:] = inImage

affines.append(affine)
if i > 20 and early_stop:
break
if getAffines:
return images, affines
else:
return images

#Load all the nifti images in the given apth
def load(path, label=False):
image_list = []
for filename in glob.glob(path + '/*.nii.gz'):
image_list.append(filename)
train_set = load_data_2D(image_list, normImage=False, categorical=label)
return train_set

#Normalise and scale to [0, 1]
def process_data(train_set):
# the method normalizes training images and adds 4th dimention

train_set = (train_set - np.mean(train_set))/ np.std(train_set)
train_set= (train_set- np.amin(train_set))/ np.amax(train_set- np.amin(train_set))
train_set = train_set[...,np.newaxis]
return train_set

#Data Generator class made with assitance of ChatGPT
class DataGenerator(tf.keras.utils.Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size

def __len__(self):
return int(np.floor(len(self.x) / self.batch_size))

def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return batch_x, batch_y

def get_X_data(path, only_test = False):
path += "keras_slices_"
if not only_test:
train_X = load(path + "train")
validate_X = load(path + "validate")
test_X = load(path + "test")
if not only_test:
train_X = process_data(train_X)
validate_X = process_data(validate_X)
test_X = process_data(test_X)
if not only_test:
return train_X, validate_X, test_X
return test_X

def get_Y_data(path, only_test = False):
path += "keras_slices_seg_"
if not only_test:
train_Y = load(path + "train", label=True)
validate_Y = load(path + "validate", label=True)
test_Y = load(path + "test", label=True)

if only_test:
return test_Y
return train_Y, validate_Y, test_Y
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/True_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/2D_UNet_46991638/graphs/UNet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
176 changes: 176 additions & 0 deletions recognition/2D_UNet_46991638/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
This file contains the UNet model used for segmentation
"""
import tensorflow as tf
#Modified from https://github.com/shakes76/PatternFlow/blob/master/recognition/MySolution/Methods.ipynb
def unet_model ():
filter_size=16
input_layer = tf.keras.Input((256,256,1))

pre_conv = tf.keras.layers.Conv2D(filter_size * 1, (3, 3), padding="same")(input_layer)
pre_conv = tf.keras.layers.LeakyReLU(negative_slope=.01)(pre_conv)


# context module 1 pre-activation residual block
conv1 = tf.keras.layers.BatchNormalization()(pre_conv)
conv1 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv1)
conv1 = tf.keras.layers.Conv2D(filter_size * 1, (3, 3), padding="same" )(conv1)
conv1 = tf.keras.layers.Dropout(.3) (conv1)
conv1 = tf.keras.layers.BatchNormalization()(conv1)
conv1 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv1)
conv1 = tf.keras.layers.Conv2D(filter_size * 1, (3, 3), padding="same")(conv1)
conv1 = tf.keras.layers.Add()([pre_conv,conv1])

# downsample and double number of feature maps
pool1 = tf.keras.layers.Conv2D(filter_size * 2, (3,3), (2,2) , padding='same')(conv1)
pool1 = tf.keras.layers.LeakyReLU(negative_slope=.01)(pool1)

# context module 2
conv2 = tf.keras.layers.BatchNormalization()(pool1)
conv2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv2)
conv2 = tf.keras.layers.Conv2D(filter_size * 2, (3, 3), padding="same")(conv2)
conv2 = tf.keras.layers.Dropout(.3) (conv2)
conv2 = tf.keras.layers.BatchNormalization()(conv2)
conv2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv2)
conv2 = tf.keras.layers.Conv2D(filter_size * 2, (3, 3), padding="same")(conv2)
conv2 = tf.keras.layers.Add()([pool1,conv2])

# downsample and double number of feature maps
pool2 = tf.keras.layers.Conv2D(filter_size*4, (3,3),(2,2), padding='same')(conv2)
pool2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(pool2)

# context module 3
conv3 = tf.keras.layers.BatchNormalization()(pool2)
conv3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv3)
conv3 = tf.keras.layers.Conv2D(filter_size * 4, (3, 3), padding="same")(conv3)
conv3 = tf.keras.layers.Dropout(.3) (conv3)
conv3 = tf.keras.layers.BatchNormalization()(conv3)
conv3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv3)
conv3 = tf.keras.layers.Conv2D(filter_size * 4, (3, 3), padding="same")(conv3)
conv3 = tf.keras.layers.Add()([pool2,conv3])

# downsample and double number of feature maps
pool3 = tf.keras.layers.Conv2D(filter_size*8, (3,3),(2,2),padding='same')(conv3)
pool3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(pool3)

# context module 4
conv4 = tf.keras.layers.BatchNormalization()(pool3)
conv4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv4)
conv4 = tf.keras.layers.Conv2D(filter_size * 8, (3, 3), padding="same")(conv4)
conv4 = tf.keras.layers.Dropout(.3) (conv4)
conv4 = tf.keras.layers.BatchNormalization()(conv4)
conv4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(conv4)
conv4 = tf.keras.layers.Conv2D(filter_size * 8, (3, 3), padding="same")(conv4)
conv4 = tf.keras.layers.Add()([pool3,conv4])


# downsample and double number of feature maps
pool4 = tf.keras.layers.Conv2D(filter_size*16, (3,3),(2,2),padding='same')(conv4)
pool4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(pool4)

# context module 5
# Middle
convm = tf.keras.layers.BatchNormalization()(pool4)
convm = tf.keras.layers.LeakyReLU(negative_slope=.01)(convm)
convm = tf.keras.layers.Conv2D(filter_size * 16, (3, 3), padding="same")(convm)
convm = tf.keras.layers.Dropout(.3) (convm)
convm = tf.keras.layers.BatchNormalization()(convm)
convm = tf.keras.layers.LeakyReLU(negative_slope=.01)(convm)
convm = tf.keras.layers.Conv2D(filter_size * 16, (3, 3), padding="same")(convm)
convm = tf.keras.layers.Add()([pool4,convm])


#upsampling module 1
deconv4 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(convm)
deconv4 = tf.keras.layers.Conv2D (filter_size *8, (3, 3) , padding="same")(deconv4)
deconv4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(deconv4)


#concatatinate layers
uconv4 = tf.keras.layers.concatenate([deconv4, conv4], axis=3)


#localization module 1
uconv4 = tf.keras.layers.Conv2D(filter_size * 16, (3, 3) , padding="same")(uconv4)
uconv4 = tf.keras.layers.BatchNormalization()(uconv4)
uconv4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv4)
uconv4 = tf.keras.layers.Conv2D(filter_size * 8, (1, 1), padding="same")(uconv4)
uconv4 = tf.keras.layers.BatchNormalization()(uconv4)
uconv4 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv4)

#upsampling module 2
deconv3 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(uconv4)
deconv3 = tf.keras.layers.Conv2D (filter_size *4, (3, 3) , padding="same")(deconv3)
deconv3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(deconv3)



# concatatinate layers
uconv3 = tf.keras.layers.concatenate([deconv3, conv3], axis=3)


# localization module 2
uconv3 = tf.keras.layers.Conv2D(filter_size * 8, (3, 3), padding="same")(uconv3)
uconv3 = tf.keras.layers.BatchNormalization()(uconv3)
uconv3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv3)
uconv3 = tf.keras.layers.Conv2D(filter_size * 4, (1, 1), padding="same")(uconv3)
uconv3 = tf.keras.layers.BatchNormalization()(uconv3)
uconv3 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv3)

# segmentation layer 1
seg3 = tf.keras.layers.Conv2D(6, (3,3), activation="softmax", padding='same')(uconv3)
# upscale segmented layer 1
seg3 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(seg3)


# Upsample module 3
deconv2 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(uconv3)
deconv2 = tf.keras.layers.Conv2D (filter_size *2, (3, 3) , padding="same")(deconv2)
deconv2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(deconv2)


# concatination layer
uconv2 = tf.keras.layers.concatenate([deconv2, conv2], axis=3)


# localization module 3
uconv2 = tf.keras.layers.Conv2D(filter_size * 4, (3, 3), padding="same")(uconv2)
uconv2 = tf.keras.layers.BatchNormalization()(uconv2)
uconv2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv2)
uconv2 = tf.keras.layers.Conv2D(filter_size * 2, (1, 1), padding="same")(uconv2)
uconv2 = tf.keras.layers.BatchNormalization()(uconv2)
uconv2 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv2)

# segmentation layer 2
seg2 = tf.keras.layers.Conv2D(6, (3,3), activation="softmax", padding='same')(uconv2)

# add segmentation layer 1 and 2
seg_32 = tf.keras.layers.Add()([seg3,seg2])
# upscale sum segmentation layer 1 and 2
seg_32 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(seg_32)


# Upsample module 4
deconv1 = tf.keras.layers.UpSampling2D(size=(2,2) , interpolation='bilinear')(uconv2)
deconv1 = tf.keras.layers.Conv2D (filter_size *1, (3, 3) , padding="same")(deconv1)
deconv1 = tf.keras.layers.LeakyReLU(negative_slope=.01)(deconv1)


# concatination layer
uconv1 = tf.keras.layers.concatenate([deconv1, conv1], axis=3 )

#final convolution layer
uconv1 = tf.keras.layers.Conv2D(filter_size * 2, (3, 3), padding="same")(uconv1)
uconv1 = tf.keras.layers.BatchNormalization()(uconv1)
uconv1 = tf.keras.layers.LeakyReLU(negative_slope=.01)(uconv1)

# final segmentation layer
seg1 = tf.keras.layers.Conv2D(6, (3,3), activation="softmax", padding='same' )(uconv1)

# sum all segmentation layers
seg_sum = tf.keras.layers.Add()([seg1,seg_32])


output_layer = tf.keras.layers.Conv2D(6, (3,3), padding='same' ,activation="softmax")(seg_sum)
model = tf.keras.Model( input_layer , outputs=output_layer)
return model
Loading