Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Topic recognition #469

Open
wants to merge 63 commits into
base: topic-recognition
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
d4943af
Initial commit and filling out basic project structure.
4lon Oct 11, 2022
32876fb
Added basic dataset class that loads all images in directory and retu…
4lon Oct 11, 2022
0c9dc73
Added template for VAE
4lon Oct 11, 2022
0d501b0
Corrected error in dataset class location, moved from train to dataset.
4lon Oct 11, 2022
4176432
Added Vector Quantizer
4lon Oct 14, 2022
067b1fc
Updated encoder and decoders
4lon Oct 14, 2022
d9e43b3
Fixed wrong layer errors and added vqvae model, likely need to fix pa…
4lon Oct 14, 2022
5cb7a40
Updated to use pytorch embeddings and fleshed out vector quantizer.
4lon Oct 20, 2022
662ef51
Cleanup
4lon Oct 20, 2022
d2611f7
Cleanup and redid vq_vae class to add more individual functions. This…
4lon Oct 20, 2022
eb289e9
Updated to make dataloaders in dataset.py
4lon Oct 20, 2022
dbea67a
Basis of training script setting up
4lon Oct 20, 2022
43cc875
Built out saving model and modularised functions
4lon Oct 20, 2022
6a60fba
Fixed up incorrect dataset formatting and loading of imgs
4lon Oct 20, 2022
40bb45c
Fleshed out training function
4lon Oct 20, 2022
eecfde1
Fleshed out testing function
4lon Oct 20, 2022
aa65485
Added image generation and saving to tensorboard for visual analysis
4lon Oct 20, 2022
9ef9451
Refactored modules to include original functions file and updated dat…
4lon Oct 20, 2022
670cc67
Increase HPC utilisation
4lon Oct 20, 2022
3c7ebce
add checks
4lon Oct 20, 2022
968ad8a
Increased Epochs
4lon Oct 20, 2022
cd008ee
Original pytorch attempt was not working, start of tensorflow attempt…
4lon Oct 20, 2022
eee3087
Old train model, not relevant anymore but good for looking back and n…
4lon Oct 20, 2022
65a76a4
Redesigned VQVAE modules for tensorflow. This is a simpler design ove…
4lon Oct 20, 2022
c54bd95
Reimplemented basic training method, but this one does not evaluate a…
4lon Oct 20, 2022
00490b4
Fixed some errors encountered when trying to train because of tracker…
4lon Oct 20, 2022
09ca20c
Added plotting of results including losses and reconstruction post le…
4lon Oct 21, 2022
53811cf
Updated directory
4lon Oct 21, 2022
bd443ee
Fixed missing param
4lon Oct 21, 2022
120f5e1
Train on more images
4lon Oct 21, 2022
d7400d1
Updated optional param
4lon Oct 21, 2022
07950f1
Imporved plotting
4lon Oct 21, 2022
e65e287
Removed print statement
4lon Oct 21, 2022
36a4200
Fixed image formatting
4lon Oct 21, 2022
ba208a4
New dataset dtype to remove reconstruction issues.
4lon Oct 21, 2022
178c251
Reduced batch size because of memory running out in new datatype (goi…
4lon Oct 21, 2022
f6eea6d
Reduced batch size again.
4lon Oct 21, 2022
ed9a611
Changed data type because space requirement was unachievable
4lon Oct 21, 2022
d32fda4
Built basic masked convolution layer and limited dataset size (memory…
4lon Oct 21, 2022
12bbfa6
Implemented residual block for pixelcnn
4lon Oct 21, 2022
31bcdef
Implemented basic pixel cnn architecture but not sure it works.
4lon Oct 21, 2022
c0ede93
Added pixelcnn training to learn codebook production
4lon Oct 21, 2022
cd4838a
Updated plotting to also save pixel cnn performance
4lon Oct 21, 2022
3660c53
Updated training with ssim metrics
4lon Oct 21, 2022
b1e9fda
reduced dataset size
4lon Oct 21, 2022
fafcd18
reduced dataset size
4lon Oct 21, 2022
b9d4ce0
reduced dataset size
4lon Oct 21, 2022
2de0a80
reduced dataset size
4lon Oct 21, 2022
570a4b6
Added sample model outputs
4lon Oct 21, 2022
df901ca
Added better sample model outputs
4lon Oct 21, 2022
14d2ebb
Refactored training to seperate pixel cnn and vqvae in case of sepera…
4lon Oct 21, 2022
b193525
Fleshing out report
4lon Oct 21, 2022
a15a457
Fleshing out report
4lon Oct 21, 2022
19d4895
Add brain generating function but doesn't really work.
4lon Oct 21, 2022
9de9e81
Removed old attempt in pytorch and added snapshots of inference.
4lon Oct 21, 2022
3fb1dfc
Fixed up incorrect function call from refactor and add SSIM evaulatio…
4lon Oct 21, 2022
fcbbb82
Finalising report
4lon Oct 21, 2022
0a71411
Fixed reconstruction a bit
4lon Oct 21, 2022
0231d01
Merge branch 'topic-recognition' into topic-recognition
4lon Nov 24, 2022
8be7c8e
Delete recognition/44801582_OASIS_VAE/samples/pixelcnn_model directory
4lon Nov 24, 2022
67ca8b2
Delete recognition/44801582_OASIS_VAE/samples/vqvae_model directory
4lon Nov 24, 2022
949fd52
Delete pixelcnn_model_weights.h5
4lon Nov 24, 2022
308597f
Delete vqvae_model_weights.h5
4lon Nov 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ dmypy.json
.idea/

# no tracking mypy config file
mypy.ini
mypy.ini
70 changes: 70 additions & 0 deletions recognition/44801582_OASIS_VAE/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# COMP3710 PatternFlow Report
## Alon Nusem - s4480158
### Project: VQVAE on the Oasis brain dataset

## Project Overview
### The VQVAE Model
A variational autoencoder is an auto encoder with modifier training that ensures a latent space is generated with better properties than just a regular autoencoder. This can help avoid overfitting and is done by returning a distribution over the latent space and adding a loss function term based on regularisation [1].

A vector quantized vae is a form of vae that uses vector quantisation to obtain this discrete latent representation.

This model in essence is a more effective way to compress, and then uncompress data into a latent space and restore it with signicant accuracy, aiming for a structured similarity over 0.6.

### The Dateset
The dataset being used for this analysis is the OASIS brain data set, captured during the OASIS brain study. This is an expansive dataset seperated into 544 test images, 1,120 validation images, and 9,664 training images of cross sectional MRI images of brains.

### The Goal
Using a vector quantized variational autoencoder, the dataset can be analysed and reduced into a more dense latent space. Training this VQVAE allows the model to essentially compress and then uncompress inputs accurately. Following this, a generational network can be designed, and by combining this network, feeding its output into the decoder from the VQVAE, new images can be created from this latent space.

## Results
After training the VQVAE on a subset of the training dataset, the model was evaluated on an unseen section of the test datset. Below is a sample of 8 brains after being reconstructed from encoding, vector quantization, and decoding after 10 epochs.

![](samples/reconstruction.png)

During this run, a SSIM of 0.91 on a sample 500 of the test dataset images. This can be re-evaluated in the predict script. While running the training the following loss plots were produced:

VQVAE:\
![](samples/training_loss_curves_vq_vae.png)

PixelCNN:\
![](samples/training_loss_curves_pixelcnn.png)

Both of these come to a plateau which suggests that there likely isn't much that more epochs of training would do. Adding more data may benefit but I touch on this in the final section.

However while these models seem to train well and VQVAE does function, there must be some issue either with pixelCNN or the generation code as new brains cannot be produced well

![](samples/Figure_2.png)

![](samples/Figure_3.png)

I'm not sure where this implementation went wrong and it requires further analysis but it does illustrate how a low dimensionality code can be transformed into a arguably more brainlike reproduction.


## How to setup this project
### Dependancies
- Python 3.9
- tensorflow=2.10.0
- tensorflow-probability=0.18.0
- scikit-image=0.18.1
- matplotlib-base=3.3.4
- numpy-base=1.21.5
- pillow=9.0.1

### Steps for reproducing
1. Setup a new conda environment with the dependancies listed above.
2. Download the dataset from https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA (This is a preprocessed set from the OASIS dataset, it also includes segmentation masks but that isn't necessary for us)
3. Extract the dataset into a folder labelled data in the 44801582_OASIS_VAE directory
4. Use train.py to train vqvae or pixelcnn
5. Use prediction.py to generate new brains (default run uses samples provided, adjustment is needed if you want to use your own model results)

## How to improve on these results (and issues)
First thing is that the current implementation of dataset is not great. It loads everything into a numpy array which is super space intensive and makes training crash unless you limit the input data size. This needs updating as it would improve training process dramatically.

The pixelCNN and VQVAE could both be improved, model wise they are not very complex implementations of their base form, better performance is possible.

## Resources
[1] https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

[] https://github.com/ritheshkumar95/pytorch-vqvae

[] https://keras.io/examples/generative/vq_vae/
35 changes: 35 additions & 0 deletions recognition/44801582_OASIS_VAE/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import numpy as np
from PIL import Image


def load_data(path, img_limit=False, want_var=False):
dataset = []

for i, img in enumerate(os.listdir(path)):
if img_limit and i > img_limit:
break
else:
img = Image.open(f"{path}/{img}")
data = np.asarray(img, dtype=np.float32)
dataset.append(data)

dataset = np.array(dataset, dtype=np.float32)

if want_var:
data_variance = np.var(dataset / 255.0)
else:
data_variance = None

dataset = np.expand_dims(dataset, -1)
dataset = (dataset / 255.0) - 0.5

return dataset, data_variance


def oasis_dataset(images= False):
train, variance = load_data("data/keras_png_slices_data/keras_png_slices_train", images, True)
test, _ = load_data("data/keras_png_slices_data/keras_png_slices_test", images)
validate, _ = load_data("data/keras_png_slices_data/keras_png_slices_validate", images)

return train, test, validate, variance
115 changes: 115 additions & 0 deletions recognition/44801582_OASIS_VAE/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import tensorflow as tf
from tensorflow import keras
import numpy as np


class VectorQuantizer(keras.layers.Layer):
def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings

self.beta = beta

w_init = tf.random_uniform_initializer()
self.embeddings = tf.Variable(
initial_value=w_init(shape=(self.embedding_dim, self.num_embeddings), dtype="float32"),
trainable=True,
name="embeddings_vqvae")

def call(self, x):
encoding_indices = self.get_code_indices(tf.reshape(x, [-1, self.embedding_dim]))
encodings = tf.one_hot(encoding_indices, self.num_embeddings)
quantized = tf.reshape(tf.matmul(encodings, self.embeddings, transpose_b=True), tf.shape(x))

self.add_loss(self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
+ tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2))

quantized = x + tf.stop_gradient(quantized - x)

return quantized

def get_code_indices(self, flattened_inputs):
similarity = tf.matmul(flattened_inputs, self.embeddings)
encoding_indices = tf.argmin((tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
+ tf.reduce_sum(self.embeddings ** 2, axis=0) - 2 * similarity), axis=1)

return encoding_indices


def def_encoder(latent_dim):
encoder_inputs = keras.Input(shape=(256, 256, 1))
x = keras.layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(
encoder_inputs
)
x = keras.layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
encoder_outputs = keras.layers.Conv2D(latent_dim, 1, padding="same")(x)
return keras.Model(encoder_inputs, encoder_outputs, name="encoder")


def def_decoder(latent_dim):
latent_inputs = keras.Input(shape=def_encoder(latent_dim).output.shape[1:])
x = keras.layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
latent_inputs
)
x = keras.layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = keras.layers.Conv2DTranspose(1, 3, padding="same")(x)
return keras.Model(latent_inputs, decoder_outputs, name="decoder")


def VQVAE(latent_dim=16, num_embeddings=64):
vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
encoder = def_encoder(latent_dim)
decoder = def_decoder(latent_dim)
inputs = keras.Input(shape=(256, 256, 1))
encoded = encoder(inputs)
quantized = vq_layer(encoded)
reconstructions = decoder(quantized)
return keras.Model(inputs, reconstructions, name="vq_vae")


class MaskedConvLayer(keras.layers.Layer):
def __init__(self, **kwargs):
super(MaskedConvLayer, self).__init__()
self.convolution = keras.layers.Conv2D(**kwargs)

def build(self, input_shape):
self.convolution.build(input_shape)
kernel_shape = self.convolution.kernel.get_shape()
self.mask = np.zeros(shape=kernel_shape)
self.mask[: kernel_shape[0] // 2, ...] = 1.0
self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0

def call(self, inputs):
self.convolution.kernel.assign(self.convolution.kernel * self.mask)
return self.convolution(inputs)


class ResidBlock(keras.layers.Layer):
def __init__(self, filters, **kwargs):
super(ResidBlock, self).__init__(**kwargs)
self.c1 = keras.layers.Conv2D(filters, 1, activation="relu")
self.mc = MaskedConvLayer(filters=filters // 2, kernel_size=3, activation="relu", padding="same")
self.c2 = keras.layers.Conv2D(filters, 1, activation="relu")

def call(self, inputs):
x = self.c1(inputs)
x = self.mc(x)
x = self.c2(x)
return keras.layers.add([inputs, x])


def PixelCNN(latent_dim, num_embeddings, num_residual_blocks, num_pixelcnn_layers):
inputs = keras.Input(def_encoder(latent_dim).layers[-1].output_shape, dtype=tf.int32)
encoding = tf.one_hot(inputs, num_embeddings)
x = MaskedConvLayer(filters=128, kernel_size=7, activation="relu", padding="same")(encoding)

for _ in range(num_residual_blocks):
x = ResidBlock(128)(x)
for _ in range(num_pixelcnn_layers):
x = MaskedConvLayer(filters=128, kernel_size=1, activation="relu", padding="valid")(x)

output = keras.layers.Conv2D(num_embeddings, 1, 1, padding="valid")(x)
pixel_cnn = keras.Model(inputs, output)

return pixel_cnn
107 changes: 107 additions & 0 deletions recognition/44801582_OASIS_VAE/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from tensorflow import keras
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import dataset
import modules
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity


def create_new_brains():
(train_data, validate_data, test_data, data_variance) = dataset.oasis_dataset(images=10)

vqvae = modules.VQVAE(16, 128)
vqvae.load_weights("samples/vqvae_model_weights.h5")

pixelcnn = modules.PixelCNN(16, 128, 2, 2)
pixelcnn.load_weights("samples/pixelcnn_model_weights.h5")

inputs = keras.layers.Input(shape=(64, 64, 16))
outputs = pixelcnn(inputs, training=False)
categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
outputs = categorical_layer(outputs)
sampler = keras.Model(inputs, outputs)

batch = 10
rows = 64
cols = 64
priors = np.zeros(shape=(batch, rows, cols))

for row in range(rows):
for col in range(cols):
priors[:, row, col] = sampler.predict(priors)[:, row, col]
print(f"{(row + 1)*(col + 1) + (col)}/{64*64}")

pretrained_embeddings = vqvae.get_layer("vector_quantizer").embeddings
one_hot = tf.one_hot(priors.astype("int32"), 128).numpy()
quantized = tf.reshape(tf.matmul(one_hot.astype("float32"),
pretrained_embeddings, transpose_b=True), (-1, *(64, 64, 16)))

generated_samples = vqvae.get_layer("decoder").predict(quantized)

for i in range(batch):
plt.subplot(1, 2, 1)
plt.imshow(priors[i], cmap='gray')
plt.title("Code")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(generated_samples[i].squeeze() + 0.5, cmap='gray')
plt.title("Generated Sample")
plt.axis("off")
plt.show()


def get_structural_similarity():
vqvae = modules.VQVAE(16, 128)
vqvae.load_weights("samples/vqvae_model_weights.h5")
_, _, test_data, _ = dataset.oasis_dataset(500)

similarity_scores = []
reconstructions_test = vqvae.predict(test_data)

for i in range(reconstructions_test.shape[0]):
original = test_data[i, :, :, 0]
reconstructed = reconstructions_test[i, :, :, 0]

similarity_scores.append(structural_similarity(original, reconstructed,
data_range=original.max() - original.min()))

average_similarity = np.average(similarity_scores)

print(average_similarity)


def plot_reconstructions():
vqvae = modules.VQVAE(16, 128)
vqvae.load_weights("samples/vqvae_model_weights.h5")
_, _, test_data, _ = dataset.oasis_dataset(500)

num_tests = 4
test_images = test_data[np.random.choice(len(test_data), num_tests)]
reconstructions = vqvae.predict(test_images)

i = 0
plt.figure(figsize=(num_tests * 2, 4), dpi=512)
for test_image, reconstructed_image in zip(test_images, reconstructions):
test_image = test_image.squeeze()
reconstructed_image = reconstructed_image[:, :, 0]
plt.subplot(num_tests, 2, 2 * i + 1, )
plt.imshow(test_image, cmap='gray')
plt.title("Original")
plt.axis("off")

plt.subplot(num_tests, 2, 2 * i + 2)
plt.imshow(reconstructed_image, cmap='gray')
plt.title(f"Reconstructed (SSIM:{structural_similarity(test_image, reconstructed_image, data_range=test_image.max() - test_image.min()):.2f})")

plt.axis("off")

i += 1

plt.show()


if __name__ == "__main__":
plot_reconstructions()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading