46974426 - pull request for merging 2d UNet into PatternAnalysis-2024 topic recognition branch #138

Open: wants to merge 28 commits into base: topic-recognition

Commits (28)
0faa320 (Oct 22, 2024): topic-recognition - initial commit including project report readme fi…
1da32d8 (Oct 22, 2024): topic-recognition - created blank python files for modules, dataset, …
443b340 (Oct 22, 2024): topic-recognition - dataset python populated with 2D data loading fun…
bfd5313 (Oct 22, 2024): topic-recognition - check if cude is available (returns true)
ce88810 (Oct 22, 2024): topic-recognition - quick test to load and display one of the slices …
10a4478 (Oct 22, 2024): topic-recognition - title and reference to initial slice image added …
bfc81fc (Oct 23, 2024): topic-recognition - tested image loading implementation and resize fu…
60eeea5 (Oct 23, 2024): topic-recognition - updated README report
6772a7d (Oct 23, 2024): topic-recognition - added first cut of UNet initialisation to module…
bb8db33 (Oct 23, 2024): topic-recognition - added first cut of dice_loss function to modules.py
e1da81d (Oct 23, 2024): topic-recognition - dice_loss function comments added
ed12eee (Oct 23, 2024): topic-recognition - Figure_1.png renamed to be more meaningfull and c…
d523966 (Oct 23, 2024): topic-recognition - first cut of training functionality added
a0f5f40 (Oct 23, 2024): topic-recognition - changes to README report
405a938 (Oct 23, 2024): topic-recognition - added accuracy and loss tracking to plot accuracy…
aa87eea (Oct 23, 2024): topic-recognition - added trained model saving after for use by test.py
12bc644 (Oct 23, 2024): topic-recognition - remove unecessary test.py file and some code added
2356e79 (Oct 24, 2024): topic-recognition - first compilable and running version (ran on goog…
43fe33e (Oct 24, 2024): topic-recognition - cleared warning/errors by referencing each file a…
4d4d7f2 (Oct 24, 2024): topic-recognition - fixed data paths and updated readme report
dd140ab (Oct 25, 2024): topic-recognition - first successful run (need to add validation data…
28610a2 (Oct 25, 2024): topic-recognition - validation data added instead of data split
cf38b5b (Oct 25, 2024): topic-recognition - updates to train script and report
04b1e25 (Oct 25, 2024): topic-recognition - more updates
0c04072 (Oct 25, 2024): topic-recognition - spelling mistakes in report cleared up
5a223fd (Oct 25, 2024): topic-recognition - updated script comments and updated the report
ea1d850 (Oct 25, 2024): topic-recognition - comment updated
c7f54eb (Oct 25, 2024): topic-recognition - could not update pdf submission note in report
72 changes: 72 additions & 0 deletions recognition/2d_unet_s46974426/README.md
@@ -0,0 +1,72 @@
PLEASE NOTE: this version differs from the PDF submission because the PDF could not be re-submitted.

COMP3710 2D UNet Report

Using 2D UNet to segment the HipMRI Study on Prostate Cancer dataset

The task for this report was to create four files (modules.py, train.py, dataset.py and predict.py) that load and segment the HipMRI study with a 2D UNet using PyTorch. The direct task description taken from Blackboard is below.

Task Description from Blackboard: "Segment the HipMRI Study on Prostate Cancer (see Appendix for link)
using the processed 2D slices (2D images) available here with the 2D UNet [1] with all labels having a
minimum Dice similarity coefficient of 0.75 on the test set on the prostate label. You will need to load
Nifti file format and sample code is provided in Appendix B. [Easy Difficulty]"
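
As a small sketch of the NIfTI loading step this refers to (the file name below is hypothetical; the full loader used in this project is in dataset.py), a single slice can be read with nibabel:

```python
import nibabel as nib

# Hypothetical path to one of the processed 2D slices (NIfTI format)
nifti = nib.load("keras_slices_train/case_000_slice_0.nii.gz")
slice_2d = nifti.get_fdata(caching="unchanged")  # 2D numpy array of pixel intensities

print(slice_2d.shape)   # e.g. (256, 128)
print(nifti.affine)     # 4x4 affine matrix mapping voxels to world coordinates
```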

I want to quickly mention that I prefixed each commit with 'topic-recognition'; this was a force of habit. Typically, when I have worked on git repositories, I first branch the solution with a name taken from a change request (e.g. "CR-123") and prefix each commit with the name of the CR.

An initial test script was run to visualise one of the slices before applying the 2D UNet, to get a sense of what the images look like. The resulting image from running test.py can be seen in slice_print_from_initial_test in the images folder.

The data loader was run in a simple for loop to check that it worked; it got roughly halfway through before erroring due to an image sizing issue. To resolve this, an image resizing function was added and is called by the data loader. The completed data loader test output can be seen in data_loader_test.png in the images folder.
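
A minimal sketch of that loader check, assuming a directory of NIfTI slices (the path is an assumption; load_data_2D and resize_image are shown in dataset.py further down):

```python
import glob
from dataset import load_data_2D

# Hypothetical directory containing the processed 2D NIfTI slices
image_names = sorted(glob.glob("keras_slices_train/*.nii.gz"))

# early_stop=True stops after ~20 images, which keeps the sanity check quick;
# resizing inside load_data_2D handles slices whose dimensions differ
images = load_data_2D(image_names, normImage=True, early_stop=True)
print(images.shape)  # (len(image_names), rows, cols)
```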

After spending time fixing errors in the original versions of the modules, dataset, predict and train scripts that I had tried, I eventually gave up on them as they would not run.

I went online and found a similar example of a 2D UNet implemented using PyTorch and adapted the code to suit my problem; the reference for this repository is given below.

Author: milesial
Date: 11/02/2024
Title of program/source code: U-Net: Semantic segmentation with PyTorch
Code version: 475
Type (e.g. computer program, source code): computer program
Web address or publisher (e.g. program publisher, URL): https://github.com/milesial/Pytorch-UNet

Also, during this process, I discovered that the masks were in the datasets suffixed with 'seg' and the images were in the datasets not suffixed with 'seg' (I had them the wrong way around originally).
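
As a small illustration of that pairing (the directory names here are assumptions, not the exact dataset layout), images and masks can be matched by the 'seg' suffix:

```python
import glob

# Hypothetical layout: images in keras_slices_train,
# segmentation masks in keras_slices_seg_train
image_files = sorted(glob.glob("keras_slices_train/*.nii.gz"))
mask_files = sorted(glob.glob("keras_slices_seg_train/*.nii.gz"))

# Paired by sorted order; each image slice lines up with its mask slice
pairs = list(zip(image_files, mask_files))
```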

After attempting to run train.py and fixing errors as they occurred, I was eventually able to run it in full and generate loss and Dice-coefficient validation plots.

I first ran the training code for 5 epochs; a graph showing the batch loss and a graph showing the Dice score can both be seen in the images folder. I then ran it for 50 epochs, and the corresponding graphs are also in the images folder. The console output while training was running can be seen in the console_running image in the images folder.

This final part outlines the working principles of the algorithm and the problem it solves.
The PyTorch UNet comprises four parts: an encoder, a decoder, a bottleneck and a final convolutional layer.
The modules script contains the UNet's definition. It also includes the Dice coefficient handling used to calculate
the Dice loss, which measures the overlap between the predicted and ground-truth masks in order to quantify a
segmentation model's accuracy. I also added a function to combine two datasets (the images and the segmentation
masks), because datasets that include both images and masks are what UNet training pipelines typically consume.
The modules script also includes some basic dataset classes, a method to load images, a check on the uniqueness
of mask values and some basic plotting logic.
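
A minimal sketch of a Dice loss of this kind (illustrative only, not the exact dice_loss in modules.py; it assumes the network output has already been passed through a sigmoid):

```python
import torch

def dice_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss: 1 minus the Dice coefficient (a measure of overlap)."""
    pred = pred.reshape(pred.shape[0], -1)        # flatten each item in the batch
    target = target.reshape(target.shape[0], -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2 * intersection + eps) / (pred.sum(dim=1) + target.sum(dim=1) + eps)
    return 1 - dice.mean()
```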

The train script initialises the UNet model defined in modules.py and then trains it on the segmentation dataset.
Before this is done, however, the data is transformed and loaded as 2D arrays using the load_data_2D function
provided in the task appendix. The train script defines the main training loop, iterating over the data in batches
and calculating losses and Dice scores, which are plotted after training has completed. It also saves progress as
the training loop completes each epoch, where an epoch is made up of a number of batches (typically 5-6 in this case).
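
A minimal sketch of such a training loop (the model, data loaders, optimiser and checkpoint names are assumptions for illustration; the actual train.py differs in its details):

```python
import torch

def train(model, train_loader, val_loader, epochs=5, lr=1e-4, device="cuda"):
    """Illustrative loop: Dice loss per training batch, Dice score on validation each epoch."""
    model = model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    batch_losses, val_dice_scores = [], []

    for epoch in range(epochs):
        model.train()
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimiser.zero_grad()
            preds = torch.sigmoid(model(images))
            loss = dice_loss(preds, masks)            # dice_loss as sketched above
            loss.backward()
            optimiser.step()
            batch_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            scores = [1 - dice_loss(torch.sigmoid(model(x.to(device))), y.to(device)).item()
                      for x, y in val_loader]
        val_dice_scores.append(sum(scores) / len(scores))

        # Save progress after each epoch so predict.py can reuse the weights
        torch.save(model.state_dict(), f"unet_epoch_{epoch}.pt")

    return batch_losses, val_dice_scores
```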

The dataset script contains the load_data_2D function as given in the appendix of the task sheet. It also
contains an image resizing function to make the image dimensions consistent.

Finally, the predict script's purpose is to generate mask predictions for new images using a trained and saved UNet model.
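
A minimal sketch of that prediction step (the checkpoint name and the 0.5 threshold are assumptions; predict.py itself is not shown in this diff):

```python
import torch

def predict_mask(model, image, checkpoint="unet_epoch_49.pt", device="cuda", threshold=0.5):
    """Load saved UNet weights and produce a binary mask for one 2D image (numpy array)."""
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model = model.to(device).eval()
    with torch.no_grad():
        x = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W)
        pred = torch.sigmoid(model(x))
    return (pred.squeeze().cpu().numpy() > threshold).astype("uint8")
```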


89 changes: 89 additions & 0 deletions recognition/2d_unet_s46974426/dataset.py
@@ -0,0 +1,89 @@
import numpy as np
import nibabel as nib
from tqdm import tqdm
import cv2
import os # Ensure you have this to work with file paths
import torch
from torch.utils.data import Dataset # Add this import for Dataset


'''
Resizes an image so that all images are consistent in size (this is required for the data loading process in load_data_2D).

Parameters:
- image: image that is being resized
- target_shape: shape the image is being resized to, e.g. 256x128 pixels
'''
def resize_image(image, target_shape):
    """Resize image to the target shape using OpenCV."""
    return cv2.resize(image, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR)

def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
    """Convert a label array into one-hot channels (one channel per unique label value)."""
    channels = np.unique(arr)
    res = np.zeros(arr.shape + (len(channels),), dtype=dtype)
    for c in channels:
        c = int(c)
        res[..., c:c + 1][arr == c] = 1
    return res

def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False):
    """
    Load medical image data from names, cases list provided into a list for each.

    Parameters:
    - imageNames: List of image file names
    - normImage: bool (normalize the image 0.0 - 1.0)
    - categorical: bool (indicates if the data is categorical)
    - dtype: Desired data type (default: np.float32)
    - getAffines: bool (return affine matrices along with images)
    - early_stop: bool (stop loading prematurely for testing purposes)

    Returns:
    - images: Loaded image data as a numpy array
    - affines: List of affine matrices (if getAffines is True)
    """
    affines = []
    num = len(imageNames)
    first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged')

    if len(first_case.shape) == 3:
        first_case = first_case[:, :, 0]  # Remove extra dims if necessary
    if categorical:
        first_case = to_channels(first_case, dtype=dtype)
        rows, cols, channels = first_case.shape
        images = np.zeros((num, rows, cols, channels), dtype=dtype)
    else:
        rows, cols = first_case.shape
        images = np.zeros((num, rows, cols), dtype=dtype)

    for i, inName in enumerate(tqdm(imageNames)):
        niftiImage = nib.load(inName)
        inImage = niftiImage.get_fdata(caching='unchanged')
        affine = niftiImage.affine

        if len(inImage.shape) == 3:
            inImage = inImage[:, :, 0]  # Remove extra dims if necessary
        inImage = inImage.astype(dtype)

        # Resize the image if necessary
        if inImage.shape != (rows, cols):
            inImage = resize_image(inImage, (rows, cols))

        if normImage:
            inImage = (inImage - inImage.mean()) / inImage.std()

        if categorical:
            inImage = to_channels(inImage, dtype=dtype)
            images[i, :, :, :] = inImage
        else:
            images[i, :, :] = inImage

        affines.append(affine)
        if i > 20 and early_stop:
            break

    if getAffines:
        return images, affines
    else:
        return images
