Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alzheimer’s Disease Classification Using Vision Transformers (ViT) #182

Open
wants to merge 62 commits into
base: topic-recognition
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
0917845
Implemented basic DataLoader with resizing and tensor conversion.
yttrium400 Oct 30, 2024
69aa94d
Added dataset size calculation for training and testing datasets.
yttrium400 Oct 30, 2024
04517d7
Added random horizontal flip to training data augmentation.
yttrium400 Oct 30, 2024
9b8ebde
Added normalization to both training and test datasets.
yttrium400 Oct 30, 2024
e72f9a9
Enabled shuffling for training data only, set fixed order for test data.
yttrium400 Oct 30, 2024
37dcb6d
Added parallel data loading with num_workers and pin_memory for impro…
yttrium400 Oct 30, 2024
7df33d9
Added informative comments to clarify code structure, functionality, …
yttrium400 Oct 30, 2024
742fa9e
Implemented basic model creation function using ViT architecture.
yttrium400 Oct 31, 2024
c2bc830
Added informative comments to explain model creation and function par…
yttrium400 Oct 31, 2024
c4b7726
Initialized basic script structure with imports and empty function.
yttrium400 Oct 31, 2024
495db6d
Added device setup to use GPU if available.
yttrium400 Oct 31, 2024
d8b5278
Loaded data using DataLoader with dataset sizes.
yttrium400 Oct 31, 2024
e5f8f74
Initialized model and move it to the appropriate device.
yttrium400 Oct 31, 2024
1a41c8c
Set up CrossEntropyLoss and AdamW optimizer.
yttrium400 Oct 31, 2024
f57417f
Added learning rate scheduler to adjust learning rate over epochs.
yttrium400 Oct 31, 2024
69fa74c
Initialized variables to track losses and accuracies.
yttrium400 Oct 31, 2024
4de0b39
Added main training loop structure with forward pass, loss computatio…
yttrium400 Oct 31, 2024
37df9e5
Compute and store training loss and accuracy for each epoch.
yttrium400 Oct 31, 2024
879579e
Added validation loop structure with forward pass and loss computation.
yttrium400 Oct 31, 2024
606cb58
Compute and store validation loss and accuracy for each epoch.
yttrium400 Oct 31, 2024
5064317
Implemented logic to save the model with the best test accuracy.
yttrium400 Oct 31, 2024
5b15d1c
Added printing of training and validation results for each epoch.
yttrium400 Oct 31, 2024
613a3bb
Added plotting logic to visualize training and validation loss and ac…
yttrium400 Oct 31, 2024
3fbcdbf
Added informative comments to explain code structure and functionality.
yttrium400 Oct 31, 2024
9228820
Added device configuration for model and refactored imports.
yttrium400 Oct 31, 2024
2918b4a
cleaning up model initialization.
yttrium400 Oct 31, 2024
83f881e
Renamed function and switched to vit_small_patch16_224 with final adj…
yttrium400 Oct 31, 2024
7fd1975
Introduced constants and a helper function for image transformations.
yttrium400 Oct 31, 2024
5078895
Using get_transforms function in data loading logic.
yttrium400 Oct 31, 2024
eef098e
Implemented separate data loaders for training and validation with da…
yttrium400 Oct 31, 2024
750d1ac
Added separate test data loader and adjusted transformations for data…
yttrium400 Oct 31, 2024
a36f4bd
Added model for saving and testing logic with confusion matrix.
yttrium400 Oct 31, 2024
9772a11
Organized main function.
yttrium400 Oct 31, 2024
460a514
Initialized basic script structure with imports and placeholder funct…
yttrium400 Oct 31, 2024
8c951db
Added model loading and set to evaluation mode.
yttrium400 Oct 31, 2024
da45389
Added loading of test data using DataLoader.
yttrium400 Oct 31, 2024
5e65943
Added prediction logic to evaluate the model on the test dataset.
yttrium400 Oct 31, 2024
5105258
Calculating and print confusion matrix and test accuracy.
yttrium400 Oct 31, 2024
70328e2
Added plotting of confusion matrix for better visualization.
yttrium400 Oct 31, 2024
0e2175d
Added informative comments to explain code functionality and purpose.
yttrium400 Oct 31, 2024
293c28c
Added doc-strings to explain function parameters and the purpose of t…
yttrium400 Oct 31, 2024
bb72603
Final changes.
yttrium400 Oct 31, 2024
271544d
added plots for loss function.
yttrium400 Oct 31, 2024
fe0a0ab
Added covariance matrix plot
yttrium400 Oct 31, 2024
c8ad454
Updated the final documentation and README.
yttrium400 Oct 31, 2024
ee36b99
Training vs Testing Accuracy graph
yttrium400 Oct 31, 2024
e0b2c33
AD brain image
yttrium400 Oct 31, 2024
38816dd
Covariance Matrix
yttrium400 Oct 31, 2024
a56fe49
Training vs Validation loss
yttrium400 Oct 31, 2024
d13a73f
NC brain image
yttrium400 Oct 31, 2024
aac7b2d
Final Accuracy image
yttrium400 Oct 31, 2024
bffcedb
Visual Transformer Architecture
yttrium400 Oct 31, 2024
6a87a94
Update README.md
yttrium400 Oct 31, 2024
569ffac
Update README.md
yttrium400 Oct 31, 2024
9b3d5df
Merge branch 'topic-recognition' of https://github.com/yttrium400/Pat…
yttrium400 Oct 31, 2024
86cb1f9
README changes
yttrium400 Oct 31, 2024
9da53f5
Update README.md
yttrium400 Oct 31, 2024
040d9de
Update README.md
yttrium400 Oct 31, 2024
6522721
added docstrings and changed README based on tutor's feedback.
yttrium400 Nov 12, 2024
6753c44
Deleted .DS_Store
yttrium400 Nov 21, 2024
ca2d7a2
Deleted unwanted files to resolve merge conflict
yttrium400 Nov 21, 2024
abbd3f7
Deleted unwanted files to resolve merge conflict
yttrium400 Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions recognition/README.md

This file was deleted.

Binary file added recognition/vit_47415056/Images/AD.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/AD1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/AD2.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/AD3.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/NC.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/NC1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/NC2.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/NC3.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/loss_graph.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added recognition/vit_47415056/Images/result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
200 changes: 200 additions & 0 deletions recognition/vit_47415056/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Alzheimer’s Disease Classification Using Vision Transformers (ViT)

**Student Number:** 47415056

**Name:** Swastik Lohchab

**Description:**
This project focuses on classifying Alzheimer’s Disease from MRI scans using Vision Transformers (ViT). The approach leverages ViT’s ability to capture spatial correlations across different regions of the brain, aiming to identify key areas associated with Alzheimer’s Disease. This project was conducted as part of the COMP3710 course at the University of Queensland and achieved a training accuracy of 67.78%, validation accuracy of 69.66% and test accuracy of 68.20%.


## Table of Contents
1. [How It Works](#1-how-it-works)
2. [Network Architecture](#2-network-architecture)
3. [Dependencies](#3-dependencies)
4. [Reproducibility](#4-reproducibility)
5. [How to Run](#5-how-to-run)
6. [Data Pre-Processing and Splits](#6-data-pre-processing-and-splits)
7. [Training and Evaluation](#7-training-and-evaluation)
8. [Results and Visualizations](#8-results-and-visualizations)
9. [Future Improvements](#9-future-improvements)
10. [Conclusion](#10-conclusion)
11. [References](#11-references)


## 1. How It Works

### Overview
The Vision Transformer (ViT) model takes an input image, divides it into patches, linearly embeds the patches, and feeds them into a series of Transformer layers to extract meaningful features for classification. In this project, we train the ViT model on a dataset of MRI images, each labeled as either ‘Normal’ or ‘Alzheimer’s Disease’.

The attention mechanism of the model underlines important brain areas that correlate with AD, hence giving insight into the model's decision-making process.

### Key Steps
1. Patch Embedding: The input image is divided into 16x16 patches, which are then flattened and linearly embedded to form a sequence.
2. Positional Encoding: Positional encodings are added to the patch embeddings to preserve spatial information.
3. Transformer Blocks: The embeddings are passed through multiple Transformer blocks, each consisting of a multi-head self-attention layer and a feed-forward neural network.
4. Classification Head: The final feature vector is used for binary classification, predicting either ‘Normal’ or ‘Alzheimer’s Disease’.


## 2. Network Architecture

### Key Components of Vision Transformer (ViT)
1. Patch Embedding Layer: Converts the input image into a sequence of flattened image patches, which can be processed by the Transformer.
2. Positional Encoding: Injects information about the spatial positions of the patches into the embeddings, as the Transformer is inherently permutation-invariant and does not have a built-in notion of order or position.
3. Transformer Blocks: The final output of each Transformer block is a refined set of patch embeddings, which are then passed to the next block in the sequence.
4. Classification Head: Transforms the output of the final Transformer block into class predictions.

### Steps of the Network Architecture
1. Input Image Processing
2. Adding Positional Encodings
3. Processing Through Transformer Blocks
4. Classification

### Benefits of Using Vision Transformers (ViT)
1. Unlike CNNs, which have a limited receptive field, ViT can capture long-range dependencies across the entire image using self-attention.
2. ViT can be easily scaled up or down by changing the number of patches, the embedding dimension, or the number of Transformer blocks.
3. Attention maps provide insights into which regions of the image are most important for the model’s decision-making process, making ViT a valuable tool for medical image analysis.

![ViT Architecture](Images/vit_architecture.jpg)


## 3. Dependencies
1. torch==2.0.1
2. torchvision==0.15.2
3. numpy==1.25.0
4. matplotlib==3.7.1
5. Pillow==9.4.0
6. scikit-learn==1.2.2 for performance metrics
7. scipy==1.11.1
8. python==3.12.4


## 4. Reproducibility

### Environment
1. Hardware: Training was conducted on Rangpur High-Performance Computing with NVidia A100 GPUs.
2. Software: Python 3.12 environment, using Anaconda.


## 5. How to Run

1. Clone the repository:

```
git clone https://github.com/yttrium400/PatternAnalysis-2024.git
cd PatternAnalysis-2024/recognition/vit_47415056
```

2. Train the model:

```
python train.py
```

This will train the ViT model on the ADNI dataset and save the trained model to `model_weights.pth` within the base vit_47415056 folder.

3. Prediction:

```
python predict.py
```


## 6. Data Pre-Processing and Splits

### Pre-Processing Steps
1. Resizing: All MRI images are resized to 224x224 pixels to ensure uniform input size, which is compatible with the Vision Transformer model.
2. Normalization: The pixel values are normalized using a mean of 0.1415 and a standard deviation of 0.2420, helping to standardize the input data and improve model convergence.
3. Data Augmentation: Various augmentation techniques are applied to increase data variability and reduce overfitting:
• Random Horizontal Flip: Flips images horizontally with a 50% probability.
• Random Vertical Flip: Adds further variability by flipping images vertically.
• Random Resized Crop: Randomly crops and resizes images to add randomness to the input images.
• Adjusting Sharpness: Modifies the sharpness of images to simulate different imaging conditions.
4. ToTensor Conversion: Images are converted to PyTorch tensors, which is necessary for inputting data into the model.
5. Shuffling: The training data is shuffled to ensure that the model does not learn any unintended patterns based on the order of the images.

### Data Structure

```
AD_NC
├── test
│ ├── AD
│ └── NC
└── train
├── AD
└── NC
```

### Images for AD
![AD image from train set](Images/AD.jpeg)
![AD image from train set](Images/AD1.jpeg)
![AD image from train set](Images/AD2.jpeg)
![AD image from train set](Images/AD3.jpeg)

### Images for NC
![NC image from train set](Images/NC.jpeg)
![NC image from train set](Images/NC1.jpeg)
![NC image from train set](Images/NC2.jpeg)
![NC image from train set](Images/NC3.jpeg)

### Splitting Strategy
The training set was further divided into 90% for training and 10% for validation. The split ensured that images from the same patient were not present in both subsets to maintain data integrity.


## 7. Training and Evaluation

### Configuration
1. Model: Vision Transformer (ViT) - "vit_small_patch16_224"
2. Optimizer: Adam with learning rate of 1e-5 and StepLR scheduler
3. Batch Size: 32
4. Number of Epochs: 10
5. Loss Function: Cross-Entropy Loss
6. Early Stopping: Triggered if validation loss did not improve for 7 epochs

### Training Loop
The training loop monitored both accuracy and loss metrics. Early stopping was implemented to prevent overfitting.


## 8. Results and Visualizations

### Performance
The final model achieved a training accuracy of 67.78%, validation accuracy of 69.66% and test accuracy of 68.20%.
![Training, Validation and Testing Accuracy](Images/result.jpg)

### Training and Validation Plots
1. Accuracy vs. Epochs: ![Training and Testing Accuracy VS No. of epochs graph](Images/accuracy_graph.jpg)
2. Loss vs. Epochs: ![Training and Testing loss VS No. of epochs graph](Images/loss_graph.jpg)

### Confusion Matrix
The confusion matrix provides insights into the model’s classification performance:

![Covariance Matrix](Images/covariance_matrix.jpg)


## 9. Future Improvements
1. Data Augmentation: Explore additional augmentation techniques to further improve model generalization.
2. Hyperparameter Tuning: Experiment with different learning rates, batch sizes, and Transformer configurations.
3. Attention Analysis: Conduct a more in-depth analysis of the attention maps to understand the model’s focus areas better.


## 10. Conclusion
This project successfully implemented a Vision Transformer (ViT) to classify Alzheimer’s Disease from MRI scans, achieving a test accuracy of 68.20%. While the model shows promise, the accuracy is limited by factors such as the complexity of MRI data and limited dataset size.

### Analysis of Training and Validation Graphs
1. Accuracy vs. Epochs: The accuracy graph shows a steady improvement in both training and validation accuracy over the epochs, indicating that the model is learning meaningful features from the data. However, there are noticeable fluctuations in the validation accuracy, which could be due to overfitting, where the model starts to memorize the training data instead of generalizing well to unseen data.
2. Loss vs. Epochs: The loss graph demonstrates a consistent decrease in training loss, but the validation loss does not drop as steadily. The gap between training and validation loss suggests that the model might be overfitting. Despite early stopping, there is still a risk that the model’s performance could degrade on new data, highlighting the need for more robust regularization techniques or a larger dataset.

### Analysis of the Confusion Matrix
The confusion matrix reveals additional insights into the model’s performance:

1. True Positives (Top-Left: 3790): The number of correctly classified Normal cases is relatively high, indicating that the model is good at identifying non-Alzheimer’s patients.
2. True Negatives (Bottom-Right: 2160): The number of correctly identified Alzheimer’s cases is also significant but lower than the true positives, suggesting the model is somewhat less effective at detecting Alzheimer’s.
3. False Positives (Top-Right: 670): These are cases where the model incorrectly classified Normal patients as having Alzheimer’s. While relatively fewer in number, false positives could lead to unnecessary concern or medical procedures.
4. False Negatives (Bottom-Left: 2380): The false negatives are concerning, as these represent Alzheimer’s cases that the model failed to identify. In a medical context, such errors are particularly critical, as undiagnosed Alzheimer’s Disease could delay necessary treatment or intervention.


## 11. References
1. Visual Transformer Architecture (ViT) - (https://huggingface.co/docs/transformers/model_doc/vit)
2. ADNI Dataset - (https://adni.loni.usc.edu)
3. ViT Overview - (https://www.geeksforgeeks.org/vision-transformer-vit-architecture/)
4. chatgpt - (https://chatgpt.com/c/6722567c-918c-800d-aae9-045e4d1dbf33)
74 changes: 74 additions & 0 deletions recognition/vit_47415056/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Constants for image processing
IMAGE_SIZE = 224
BATCH_SIZE = 32

def get_transforms(is_train=True):
"""
Returns transformations for training or testing images.

Args:
is_train (bool): True for training transforms, False for testing.

Returns:
torchvision.transforms.Compose: Composed transformations.
"""
if is_train:
# Data augmentation for training
return transforms.Compose([
transforms.RandomResizedCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.RandomAdjustSharpness(sharpness_factor=0.9, p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.1415] * 3, std=[0.2420] * 3),
])
else:
# Basic transforms for testing
return transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(mean=[0.1415] * 3, std=[0.2420] * 3),
])

def get_train_val_loaders(data_dir):
"""
Creates DataLoaders for training and validation sets.

Args:
data_dir (str): Directory with image data.

Returns:
tuple: (train_loader, val_loader) for training and validation.
"""
transform = get_transforms(is_train=True)
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Split dataset into training and validation sets
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=6)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=6)
return train_loader, val_loader

def get_test_loader(data_dir):
"""
Creates a DataLoader for the test set.

Args:
data_dir (str): Directory with test image data.

Returns:
DataLoader: DataLoader for the test set.
"""
transform = get_transforms(is_train=False)
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# DataLoader for test dataset
return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=6)
21 changes: 21 additions & 0 deletions recognition/vit_47415056/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from timm import create_model

# Configure the device to use GPU if available, otherwise use CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize_model():
"""
Initializes and returns a Vision Transformer (ViT) model.

The model uses a smaller architecture variant, 'vit_small_patch16_224',
without pretrained weights and is configured for 2 output classes.

Returns:
- model (torch.nn.Module): ViT model moved to the specified device.
"""

# Create a Vision Transformer (ViT) model using 'vit_small_patch16_224'
model = create_model("vit_small_patch16_224", pretrained=False, num_classes=2)

return model.to(DEVICE)
60 changes: 60 additions & 0 deletions recognition/vit_47415056/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import torch
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
from modules import initialize_model
from dataset import get_test_loader
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the folder structure if it does not exist
output_dir = "vit_47415056/graphs"
os.makedirs(output_dir, exist_ok=True)

def predict_and_visualize(model_path, test_data_dir):
"""
Load a model, make predictions on test data, and visualize the results.

Parameters:
- model_path (str): Path to the saved model weights.
- test_data_dir (str): Directory path for test data.
"""

model = initialize_model()
model.load_state_dict(torch.load(model_path))
model.eval()

test_loader = get_test_loader(test_data_dir)
all_preds, all_labels = [], []

print("\nTesting the model again on the test dataset...\n")
for images, labels in test_loader:
images = images.to(DEVICE)
with torch.no_grad():
predictions = model(images).argmax(dim=1)
all_preds.extend(predictions.cpu().numpy())
all_labels.extend(labels.numpy())

conf_matrix = confusion_matrix(all_labels, all_preds)
accuracy = accuracy_score(all_labels, all_preds)
print(f"Confusion Matrix:\n{conf_matrix}\nTest Accuracy: {accuracy:.2%}")

# Plotting the confusion matrix as a covariance matrix (heatmap) with numbers
plt.figure()
plt.matshow(conf_matrix, cmap='viridis')
plt.title("Confusion Matrix (Covariance Matrix)")
plt.colorbar()
plt.xlabel("Predicted")
plt.ylabel("Actual")

# Add numbers to each cell
for (i, j), value in np.ndenumerate(conf_matrix):
plt.text(j, i, f"{value}", ha="center", va="center", color="white")

plt.savefig(os.path.join(output_dir, "covariance_matrix.png"))

if __name__ == "__main__":
model_path = "model_weights.pth"
test_data_dir = "/home/groups/comp3710/ADNI/AD_NC/test"
predict_and_visualize(model_path, test_data_dir)
Loading