Segmenting MRI images using 2D Unet #142

Open · wants to merge 24 commits into base: topic-recognition

Commits (24)
26f41ff
Modified dataset.py to add the dataloader for the three keras_slices …
Oct 7, 2024
d2be274
added the loading of the segmented data to dataset.py as previously i…
Oct 7, 2024
10fb337
Added encoder, bottleneck, decoder and unet functions to modules.py.
Oct 10, 2024
e0f1c20
Tested the functionality of the model with a subset of data.
Oct 10, 2024
6ab6d73
Adding files to rangpur to test model. using 50 epochs and a batch si…
Oct 10, 2024
b7c029a
Added loss, accuracy and dice score variables to store the results of…
Oct 11, 2024
4bb12fe
Modified train to use Binary cross entropy, this allows it to run on …
Oct 16, 2024
4266105
Added code to predict.py to show the first 5 original images,
Oct 21, 2024
3e732fb
Cleaned imports and added dice
Oct 21, 2024
e783c23
Modifying training.py for accuracy
Oct 21, 2024
679dc2c
Sorted data and changed folder
Oct 21, 2024
986328a
Added a learning rate scheduler to help dice score
Oct 21, 2024
5575436
Changed to combined loss function better accuracy
Oct 22, 2024
4cc236c
Fixed git issues
Oct 26, 2024
56460f4
Added plotting functions to predict.py
Oct 26, 2024
8ed55f8
Ran the unet with 3, 6 and 12 epochs and saved the images
Oct 26, 2024
f9d7a57
Going to do a pull request
Oct 26, 2024
a511b70
Delete README.md
NicMarchant Oct 26, 2024
7007c17
Delete recognition/README.md
NicMarchant Oct 26, 2024
01c0a23
Delete recognition/dataset.py
NicMarchant Oct 26, 2024
be7f9fa
Delete recognition/modules.py
NicMarchant Oct 26, 2024
4758cd6
Delete recognition/predict.py
NicMarchant Oct 26, 2024
2ed4855
Delete recognition/train.py
NicMarchant Oct 26, 2024
bddf312
Delete recognition/__pycache__ directory
NicMarchant Oct 26, 2024
19 changes: 0 additions & 19 deletions README.md

This file was deleted.

10 changes: 0 additions & 10 deletions recognition/README.md

This file was deleted.

83 changes: 83 additions & 0 deletions recognition/unet_hipmri_s4646244/README.md
@@ -0,0 +1,83 @@
# 2D Unet to segment MRI images of prostate cancer
This repository contains a TensorFlow Keras implementation of a binary-classification Unet for segmenting prostate cancer MRI scans.

#### Files:
dataset.py
- Reads the .nii files into training, validation and test sets

modules.py
- Contains the Unet architecture, including the encoder, decoder and bottleneck

train.py
- Contains the functions to train the Unet and save the model

predict.py
- Shows how well the Unet predicts the test images

#### Folders
Unet_images
- Contains the test result images from runs with different numbers of epochs

### Model
###### Filters
The numbers of filters applied at each encoder level are [64, 128, 256, 512], and [512, 256, 128, 64] at each decoder level.
###### Encoder
The encoder takes the input tensor and for each filter in the list completes two convolutions with a 3x3 kernel size and a relu activation function.
It then saves the resulting tensor into the skip connection so it can be used in the decoder.
A 2x2 max pooling is then used and the resulting tensor output.
This happens four times, once for each filter.
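
A minimal sketch of a single encoder level, matching the description above (the full loop over all four filter sizes is in modules.py; `encoder_block` is an illustrative name only):

```python
from tensorflow.keras.layers import Conv2D, MaxPooling2D

def encoder_block(tensor, filters):
    # Two 3x3 convolutions with ReLU, as described above
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(tensor)
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    skip = x                                          # kept for the matching decoder level
    pooled = MaxPooling2D((2, 2), padding='same')(x)  # halve the spatial resolution
    return pooled, skip
```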

###### Bottleneck
The bottleneck applies two convolutions with 1024 filters, a 3x3 kernel and the ReLU activation function.
It then outputs the tensor to be used in the decoder.

###### Decoder
The decoder takes this tensor and, for each filter size in reverse order, performs an up-convolution (transposed convolution) with a 2x2 kernel, a stride of 2 and a ReLU activation.
The relevant skip connection is then concatenated with the tensor from the up-convolution.
This concatenated tensor is fed through two convolutions with a 3x3 kernel and ReLU activation.
One last 1x1 convolution with a sigmoid activation completes the binary classification.
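
A minimal sketch of one decoder level plus the final sigmoid output, again with illustrative names (the full reversed loop is in modules.py):

```python
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, concatenate

def decoder_block(tensor, skip, filters):
    # Up-convolution, concatenation with the matching skip connection, then two 3x3 convolutions
    x = Conv2DTranspose(filters, (2, 2), strides=2, padding='same', activation='relu')(tensor)
    x = concatenate([x, skip])
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    return x

def output_layer(tensor):
    # 1x1 convolution with sigmoid gives a per-pixel binary probability map
    return Conv2D(1, (1, 1), padding='same', activation='sigmoid')(tensor)
```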

### Training
The Unet is trained with the Adam optimiser and an initial learning rate of 0.0001. It uses a combined loss function consisting of binary cross entropy and dice loss, as this resulted in the best performance.
The training also uses early stopping, which halts training when the validation loss does not improve for three consecutive epochs.
A learning rate scheduler is also used to monitor the validation loss; if it does not improve for two epochs, the learning rate is halved.
For the reported results a maximum of twelve epochs was set to reduce training time, however only six were run because the model stopped early.
A validation set is used to evaluate performance after each epoch, helping to monitor overfitting by checking how well the model performs on new data.
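
A minimal sketch of this training setup, assuming a standard soft Dice coefficient (the actual definitions live in train.py; the smoothing constant and the exact form of the combined loss are assumptions):

```python
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from modules import unet

def dice_metric(y_true, y_pred, smooth=1.0):
    # Soft Dice coefficient over the whole batch (smoothing constant is an assumption)
    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def combined_loss(y_true, y_pred):
    # Binary cross entropy plus Dice loss, as described above
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + (1.0 - dice_metric(y_true, y_pred))

model = unet()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=combined_loss, metrics=[dice_metric])

callbacks = [
    EarlyStopping(monitor='val_loss', patience=3),                  # stop after 3 epochs with no improvement
    ReduceLROnPlateau(monitor='val_loss', patience=2, factor=0.5),  # halve the LR after 2 stalled epochs
]

# Training would then look roughly like this (images/masks may need an added channel axis):
# trainResults = model.fit(trainImages, trainSegImages, batch_size=4, epochs=12,
#                          validation_data=(validateImages, validateSegImages),
#                          callbacks=callbacks)
```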

### Testing
After the model is trained it is given new data to segment. This data is the held-out test set, which the model has not seen during training; predict.py runs the trained model over it and reports per-image Dice scores alongside the predicted segmentations.

### Performance
The model was run with a batch size of four and a maximum of twelve epochs; only six were completed because training stopped early.
The results for the model trained for three and six epochs can be seen in the Unet_images folder.

The model is very good at binary segmentation of the prostate cancer images; if I had more time I would convert it to do multi-class segmentation.
It segments most of the regions well, except for very small areas.
The mean Dice score on the test set is just above 0.65, although many individual data points fall below this value.
The Dice coefficient over the epochs increases sharply and then slowly tapers off; the loss behaves the same way except that it sharply decreases.

#### Required dependencies
- TensorFlow (for Keras layers, models, and callbacks)
- NumPy (for numerical operations)
- Matplotlib (for plotting)
- NiBabel (for neuroimaging data handling)
- tqdm (for progress bars)
- scikit-image (for image transformations)
- pathlib (standard library; for filesystem path manipulations)

#### Future improvements
The dataset given is for multi-class segmentation, but the implementation of my Unet and training only does binary classification. I tried implementing the multi-class model by using a softmax activation function rather than sigmoid and modifying my train.py to handle the multiple classes, however I could not get it working. In the future I will look into modifying the implementation so that it can do multi-class segmentation; a sketch of this direction is shown below.
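
A rough sketch of that multi-class direction (not a working part of this repository; the number of classes is an assumption and would need to match the HipMRI label set):

```python
from tensorflow.keras.layers import Conv2D

NUM_CLASSES = 6  # assumption: adjust to the number of labels in the dataset

def multiclass_output(tensor):
    # One output channel per class, softmax instead of sigmoid
    return Conv2D(NUM_CLASSES, (1, 1), padding='same', activation='softmax')(tensor)

# The masks would be one-hot encoded (load_data_2D(..., categorical=True) already does this),
# and training would use categorical cross entropy plus a multi-class Dice term.
```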

#### How to run
Before running the model, the relevant file paths need to be set in dataset.py (see the sketch below).
Once this is done, all that needs to be run is predict.py with no arguments.
This will train, validate, test and print the results of the model.
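
The variables to edit in dataset.py look like this (the paths below are placeholders for your local copy of the data):

```python
# In dataset.py: point these at the HipMRI_study_keras_slices_data folders on your machine
testDir        = '/path/to/keras_slices_test/'
trainDir       = '/path/to/keras_slices_train/'
validateDir    = '/path/to/keras_slices_validate/'

testSegDir     = '/path/to/keras_slices_seg_test/'
trainSegDir    = '/path/to/keras_slices_seg_train/'
validateSegDir = '/path/to/keras_slices_seg_validate/'
```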

If your system has a powerful graphics card, you may be able to increase the batch size in train.py; this will result in faster training.

### References
Reference for the Dice coefficient metric implementation in train.py:
Stack Overflow. "Dice coefficient not increasing for U-Net image segmentation."
https://stackoverflow.com/questions/67018431/dice-coefficent-not-increasing-for-u-net-image-segmentation


(12 added binary files, likely the result images for the Unet_images folder, could not be rendered in the diff view.)
107 changes: 107 additions & 0 deletions recognition/unet_hipmri_s4646244/dataset.py
@@ -0,0 +1,107 @@
import numpy as np
import nibabel as nib
from tqdm import tqdm
import skimage.transform as skTrans
from pathlib import Path
import tensorflow as tf

def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
channels = np.unique(arr)
res = np.zeros(arr.shape + (len(channels),), dtype=dtype)
for c in channels:
c = int(c)
res[..., c:c + 1][arr == c] = 1
return res

# load medical image functions
def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False):
'''
Load medical image data from the provided list of file names into a single array (one slice per case).

This function pre-allocates 4D arrays for conv2d to avoid excessive memory usage.

normImage: bool (normalise each image to zero mean and unit standard deviation)
early_stop: Stop loading prematurely, leaves arrays mostly empty, for quick loading and testing scripts.
'''
affines = []

# get fixed size
num = len(imageNames)
first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged')
if len(first_case.shape) == 3:
first_case = first_case[:, :, 0] # sometimes extra dims, remove
if categorical:
first_case = to_channels(first_case, dtype=dtype)
rows, cols, channels = first_case.shape
images = np.zeros((num, rows, cols, channels), dtype=dtype)
else:
rows, cols = first_case.shape
images = np.zeros((num, rows, cols), dtype=dtype)

for i, inName in enumerate(tqdm(imageNames)):
niftiImage = nib.load(inName)
inImage = niftiImage.get_fdata(caching='unchanged') # read disk only

affine = niftiImage.affine
if len(inImage.shape) == 3:
inImage = inImage[:, :, 0] # sometimes extra dims in HipMRI_study data

# Converts the image to a 256,128 image
inImage = skTrans.resize(inImage, (256, 128), order=1, preserve_range=True)

inImage = inImage.astype(dtype)
if normImage:
# ~ inImage = inImage / np.linalg.norm(inImage)
# ~ inImage = 255. * inImage / inImage.max()
inImage = (inImage - inImage.mean()) / inImage.std()

if categorical:
inImage = to_channels(inImage, dtype=dtype)
images[i, :, :, :] = inImage

else:
images[i, :, :] = inImage

affines.append(affine)
if i > 20 and early_stop:
break

if getAffines:
return images, affines
else:
return images

testDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_test/'
trainDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_train.large/'
validateDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_validate/'

testSegDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_seg_test/'
trainSegDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_seg_train.large/'
validateSegDir = '/home/kankuna/Documents/COMP3710DATA/HipMRI_study_keras_slices_data/keras_slices_seg_validate/'

# Load the scans

# Load the test images
testListNii = sorted(Path(testDir).glob('*.nii'))
testImages = load_data_2D(testListNii, normImage=True, categorical=False)

# Load the training images
trainListNii = sorted(Path(trainDir).glob('*.nii'))
trainImages = load_data_2D(trainListNii, normImage=True, categorical=False)

# Load the validation images
validateListNii = sorted(Path(validateDir).glob('*.nii'))
validateImages = load_data_2D(validateListNii, normImage=True, categorical=False)

# Load the segmented test scans
testSegListNii = sorted(Path(testSegDir).glob('*.nii'))
testSegImages = load_data_2D(testSegListNii, normImage=True, categorical=False)

# Load the segmented training scans
trainSegListNii = sorted(Path(trainSegDir).glob('*.nii'))
trainSegImages = load_data_2D(trainSegListNii, normImage=True, categorical=False)

# Load the segmented validation scans
validateSegListNii = sorted(Path(validateSegDir).glob('*.nii'))
validateSegImages = load_data_2D(validateSegListNii, normImage=True, categorical=False)
51 changes: 51 additions & 0 deletions recognition/unet_hipmri_s4646244/modules.py
@@ -0,0 +1,51 @@
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate # type: ignore
from tensorflow.keras.models import Model # type: ignore

# List of filters to be applied in the encoding and reverse in the decoding step
filterList = [64,128,256,512]

# Function that does the encoding of the Unet by applying convolutions and max pooling to the given tensor at each filter level
# Parameters: inputTensor, a tf tensor that will be encoded
# Returns: (tensor, skipConnectionList), a tuple containing the resulting encoded tensor and a list of the skip connections at each step
def encoder(inputTensor):
skipConnectionList = []
tensor = inputTensor
for filter in filterList:
firstConv = Conv2D(filter, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(tensor)
secondConv = Conv2D(filter, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(firstConv)
skipConnectionList.append(secondConv)
tensor = MaxPooling2D(pool_size = (2,2), padding = 'same')(secondConv)
return tensor, skipConnectionList

# Function that does the decoding of the Unet by applying an up convolution, concatenating the result with the matching skip connection, and then applying convolutions.
# Parameters: skipConnectionList, the list of encoder skip connections; inputTensor, a tf tensor that will be decoded
# Returns: finalConv, the decoded tensor after a final 1x1 sigmoid convolution
def decoder(skipConnectionList, inputTensor):
tensor = inputTensor
for filter in reversed(filterList):
upConv = Conv2DTranspose(filter, kernel_size = (2,2), padding = 'same', activation = 'relu', strides = 2)(tensor)
skipConnection = skipConnectionList.pop()
concatTensor = concatenate([upConv, skipConnection])
firstConv = Conv2D(filter, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(concatTensor)
secondConv = Conv2D(filter, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(firstConv)
tensor = secondConv
finalConv = Conv2D(1, kernel_size=(1, 1), padding='same', strides=1, activation='sigmoid')(tensor)
return finalConv

# Function that applies the bottleneck of the Unet.
# Parameters: inputTensor, a tf tensor that will be passed through the bottleneck
# Returns: tensor, a tensor that has gone through two 1024-filter convolutions
def bottleneck(inputTensor):
firstConv = Conv2D(1024, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(inputTensor)
secondConv = Conv2D(1024, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(firstConv)
tensor = secondConv
return tensor

# Function that combines the encoder, bottleneck and decoder into one Unet model
# Returns a keras model of the Unet
def unet():
inputs = Input(shape = (256, 128, 1))
encodedResult, skipConnectionList = encoder(inputs)
bottleneckResult = bottleneck(encodedResult)
decodedResult = decoder(skipConnectionList, bottleneckResult)
return Model(inputs=[inputs], outputs=[decodedResult])
86 changes: 86 additions & 0 deletions recognition/unet_hipmri_s4646244/predict.py
@@ -0,0 +1,86 @@
import matplotlib.pyplot as plt
from train import unetModel, dice_metric, trainResults
from dataset import testImages, testSegImages
import numpy as np

testPredictedSeg = unetModel.predict(testImages)
print(np.unique(testPredictedSeg))

# Function to compute the Dice score for each pair of ground-truth and predicted segmentations
def calculate_dice_scores(y_true, y_pred):
dice_scores = []
for i in range(len(y_true)):
y_pred_squeezed = np.squeeze(y_pred[i])
score = dice_metric(y_pred_squeezed, y_true[i]).numpy()
dice_scores.append(score)
return dice_scores

dice_scores = calculate_dice_scores(testSegImages, testPredictedSeg)
dice_scores = np.array(dice_scores)

# Plot the original image, the ground-truth segmentation and the predicted segmentation for the first 5 test images
fig, pos = plt.subplots(5, 3, figsize=(15, 25))
for i in range(5):
# Display original image
pos[i, 0].imshow(testImages[i].squeeze())
pos[i, 0].set_title(f'Original image {i+1}')
pos[i, 0].axis('off')

# Display actual segmentation
pos[i, 1].imshow(testSegImages[i].squeeze())
pos[i, 1].set_title(f'Actual segmentation {i+1}')
pos[i, 1].axis('off')

# Display predicted segmentation
pos[i, 2].imshow(testPredictedSeg[i].squeeze())
pos[i, 2].set_title(f'Predicted segmentation {i+1}')
pos[i, 2].axis('off')
plt.tight_layout()
plt.show()

# Plot the Dice score for each test image and the distribution of scores
plt.figure(figsize=(12, 6))
plt.plot(dice_scores, marker='o', linestyle='None', color='b')
plt.title('Dice Scores for Each Test Image')
plt.xlabel('Test Image Index')
plt.ylabel('Dice Score')
plt.ylim(0, 1)
plt.yticks(np.linspace(0, 1, num=11))
plt.grid()
plt.axhline(y=np.mean(dice_scores), color='r', linestyle='--', label='Mean Dice Score')
plt.legend()
plt.show()

plt.figure(figsize=(12, 6))
plt.hist(dice_scores, bins=10, color='c', edgecolor='black', alpha=0.7)
plt.title('Distribution of Dice Scores')
plt.xlabel('Dice Score')
plt.ylabel('Frequency')
plt.xlim(0, 1)
plt.grid()
plt.show()

# Plotting Loss and Dice Coefficient
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(trainResults.history['loss'], label='Training Loss')
plt.plot(trainResults.history['val_loss'], label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid()

# Plot Dice Coefficient
plt.subplot(1, 2, 2)
plt.plot(trainResults.history['dice_metric'], label='Training Dice Coefficient')
plt.plot(trainResults.history['val_dice_metric'], label='Validation Dice Coefficient')
plt.title('Dice Coefficient Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Dice Coefficient')
plt.legend()
plt.grid()

plt.tight_layout()
plt.show()