Topic recognition #436

Status: Open. This PR wants to merge 26 commits into base: topic-recognition.

Commits (26)
2b1ef28
initial commit made necessary files
jbart178 Oct 15, 2022
bc8ad99
made vector quantizer
jbart178 Oct 16, 2022
70a3a22
made rough vqvae model
jbart178 Oct 16, 2022
c7fc18f
edited gitignore to ignore dataset stored locally
jbart178 Oct 16, 2022
eefebeb
fixed typo
jbart178 Oct 16, 2022
ebb6ff0
loaded datasets and applied scaling
jbart178 Oct 16, 2022
02dafd4
made dataloading function and little preprocessing
jbart178 Oct 17, 2022
8083f66
made model trainer
jbart178 Oct 17, 2022
bf5fabb
bug fixing and training testing
jbart178 Oct 17, 2022
3c16cf4
refactoring to modules
jbart178 Oct 17, 2022
6c15f7d
added residual layers and refactored accordingly
jbart178 Oct 17, 2022
b9fa0e9
fixed a silly typo
jbart178 Oct 17, 2022
fa55571
edit gitignore, added save model
jbart178 Oct 19, 2022
5696e84
add multiprocessing option to dataloading
jbart178 Oct 19, 2022
5c1424e
variable residual hidden layers added and add png to gitignore
jbart178 Oct 19, 2022
f607136
made pixelcnn and trainer
jbart178 Oct 20, 2022
837ad87
made sampler and testing for pixel cnn
jbart178 Oct 20, 2022
6059faa
code refactoring
jbart178 Oct 20, 2022
949504e
code refactoring
jbart178 Oct 20, 2022
b075b27
made driver script
jbart178 Oct 20, 2022
646eee0
refactoring making more modular and nonhardcoded
jbart178 Oct 20, 2022
41cc6bf
code refactoring and hyperparameters
jbart178 Oct 20, 2022
d7d1952
made readme and edited model parameters
jbart178 Oct 20, 2022
18144d2
hyperparameters and vqvae residual block edit
jbart178 Oct 21, 2022
543fa9d
final changes to readme and driver
jbart178 Oct 21, 2022
84c2fb1
something changed
jbart178 Oct 23, 2022
6 changes: 1 addition & 5 deletions .gitignore
@@ -1,4 +1,3 @@
recognition/s4481540_Zhuoxiao_Chen/data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -120,10 +119,7 @@ venv.bak/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
encoder_out.pyre/

# vscode config file
.vscode/
10 changes: 10 additions & 0 deletions recognition/45819061-VQVAE-OASIS/.gitignore
@@ -0,0 +1,10 @@

# ignore data folder
data/

# constructed model
vqvae/
pixelcnn/

# all images
**.png
42 changes: 42 additions & 0 deletions recognition/45819061-VQVAE-OASIS/README.md
@@ -0,0 +1,42 @@
# Vector Quantized Variational Autoencoder (VQ-VAE) for Generation of OASIS Brain Data


Here we construct a Vector Quantized Variational Autoencoder (VQ-VAE) trained on the OASIS brain dataset: a generative model that can reproduce and generate brain scan images by sampling a discrete latent space much smaller than the images themselves.

# Problem
Computer recognition and classification of brain disease is a growing field that aims to develop models able to identify problems and characteristics in the scans of a patient's brain. A current limitation on the effectiveness of this technology is the shortage of data for training such classification models, so the models produced are often undertrained and ineffective. We use a VQ-VAE to learn the structure and characteristics of brain scans and encode them into a smaller, compact latent space. We then learn the patterns and structure of this latent space and train a generative model that generates clear, new brain scans, which can in turn be used to train those classification models.

# The Model
The model we train is a VQ-VAE: an encoder feeding into a vector quantizer (VQ) layer whose output feeds into a decoder. The encoder and decoder each consist of two convolutional blocks and two residual layers. The convolutional layers use 3x3 windows with stride 2, so together they reduce each spatial dimension of the image by a factor of four before passing the data to the residual layers; we use filter sizes of 32 and 64. Each residual layer applies two convolutions (3x3 and 1x1) with filter size 32 and leaky ReLU activations between them, and the output of the residual block is the sum of this convolutional path with its original input. The VQ layer holds a codebook of embedding vectors: it takes the output of the encoder and computes the distance to each embedding to find the image's place in the latent space. Intuitively, the encoder hands the VQ layer the key characteristics it has identified in the image, and the VQ layer assigns each position the index of the codebook entry where that information lives. Finally, the decoder takes a grid of codewords from the latent space and rebuilds the image via two transposed convolutional layers and residual blocks. During training the VQ-VAE balances maintaining the integrity of its vector quantisation of the latent space against the fidelity of its image reconstructions.
For generation we train a PixelCNN on the latent space discovered by the VQ-VAE, sampling it to produce new codes to pass to the decoder and generate realistic brain scans.
The model we developed is based on that described in [Paper](https://arxiv.org/abs/1711.00937).
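
As a minimal, self-contained sketch of the nearest-codebook lookup at the heart of the VQ layer (the codebook values here are made up; modules.py stores embeddings with one code per column, as assumed below):

```python
import numpy as np

# Toy codebook: 4 embeddings of dimension 2, one per column, as in modules.py.
embeddings = np.array([[0.0, 1.0, 0.0, 1.0],
                       [0.0, 0.0, 1.0, 1.0]], dtype="float32")

z = np.array([[0.1, 0.9]], dtype="float32")  # one flattened encoder output vector

# Squared Euclidean distance expanded as ||z||^2 + ||e||^2 - 2 z.e
distances = (z**2).sum(1, keepdims=True) + (embeddings**2).sum(0) - 2 * z @ embeddings
print(distances.argmin(1))  # [2]: index of the nearest code, here e2 = (0, 1)
```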

# Requirements
Versioning need not be strictly matched; the versions below are what we used.
- tensorflow = 2.10.0
- tensorflow-probability = 0.18.0
- tqdm = 4.64.1
- matplotlib = 3.6.1
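
With these installed, running `python driver.py` from `recognition/45819061-VQVAE-OASIS/` (with the dataset unpacked under `data/keras_png_slices_data`, the path assumed in driver.py) trains both models end-to-end and produces the figures shown below.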

# Training
We train the models with Adam optimizers, tracking commitment loss, codebook loss and reconstruction loss for the VQ-VAE, and categorical cross-entropy for the PixelCNN. The loss function for the VQ-VAE is described in [Paper](https://arxiv.org/abs/1711.00937); it is essentially the distance between the model's output at various stages (after encoding, after decoding) and the expected values at those points, and it is designed to improve reconstruction clarity while keeping the latent space meaningful and interpretable by the later PixelCNN.
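
Concretely, the paper's training objective, with $\mathrm{sg}$ the stop-gradient operator, $z_e(x)$ the encoder output, $e$ the nearest codebook vector and $\beta$ the commitment weight (0.25 in modules.py), is:

```math
L = \log p\big(x \mid z_q(x)\big)
  + \big\lVert \mathrm{sg}[z_e(x)] - e \big\rVert_2^2
  + \beta\, \big\lVert z_e(x) - \mathrm{sg}[e] \big\rVert_2^2
```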


We train the model using the VQVAETrainer class, which contains all the logic required for training. In our experiment we trained the model over the entire training set provided with the OASIS brain data for 50 epochs. Relevant parameters such as the dimension of the embedding space and the filter sizes of the layers are set in the driver.py script, which trains a VQ-VAE model and a PixelCNN model and produces figures showing training statistics and outputs of the final models. We include our findings below, after a hedged sketch of what the training step looks like.
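
train.py itself is not part of this diff, so the following is a sketch only, following the standard Keras custom-training pattern; the names and details are assumptions, not the PR's actual code:

```python
import tensorflow as tf

# Hypothetical sketch of a VQVAETrainer; the real train.py may differ.
class VQVAETrainer(tf.keras.Model):
    def __init__(self, vqvae, **kwargs):
        super().__init__(**kwargs)
        self.vqvae = vqvae
        self.loss_tracker = tf.keras.metrics.Mean(name="total_loss")

    def train_step(self, x):
        with tf.GradientTape() as tape:
            reconstructions = self.vqvae(x)
            # Reconstruction loss plus the commitment/codebook losses the
            # VQ layer registered via add_loss.
            recon_loss = tf.reduce_mean((x - reconstructions) ** 2)
            total_loss = recon_loss + sum(self.vqvae.losses)
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
        self.loss_tracker.update_state(total_loss)
        return {"loss": self.loss_tracker.result()}

# Usage would look roughly like:
#   trainer = VQVAETrainer(get_vqvae(latent_dim=32, num_embeddings=128))
#   trainer.compile(optimizer=tf.keras.optimizers.Adam())
#   trainer.fit(x_train, epochs=50, batch_size=64)
```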

![](losses.png)
![](ssim.png)

Here are some example input/output pairs for the autoencoder:
![](fig1.png)
![](fig2.png)

And their representations in the codebook space at the bottleneck of the autoencoder:
![](embedding1.png)
![](embedding2.png)

And finally the codes sampled by the PixelCNN together with the images decoded from them:
![](gen1.png)
![](gen2.png)
# Data
The data we used is the preprocessed OASIS brain data available here: [Link](https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA). Since this data is already split into training, validation and testing sets, we did not perform any further dataset splitting. Before passing images to the model we loaded them as grayscale and normalised pixel values to the range [-0.5, 0.5].
43 changes: 43 additions & 0 deletions recognition/45819061-VQVAE-OASIS/dataset.py
@@ -0,0 +1,43 @@
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm




def reader(f):
    # Decode a single PNG file as a one-channel (grayscale) image tensor.
    return tf.io.decode_png(tf.io.read_file(f), channels=1)

def load(files, use_multiprocessing=False):
    # Optionally fan the PNG decoding out over a pool of worker processes;
    # use_multiprocessing doubles as the worker count when it is an integer.
    if use_multiprocessing:
        import multiprocessing
        pool = multiprocessing.Pool(use_multiprocessing)
        lst = pool.map(reader, tqdm(files))
        pool.close()
    else:
        lst = map(reader, tqdm(files))

    imgs = np.asarray(list(lst), dtype='float32')
    return imgs

"""
Load data from predefined paths TRAIN_DATA, TEST_DATA, VALIDATE_DATA.
optional argument use_multiprocessing defaults to false can specify and integer to spawn child
processes to load faster on machines with sufficient capabilities
"""
def get_data(train_dir, test_dir, validate_dir, use_multiprocessing=False):
    files_train = [os.path.join(train_dir, f) for f in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, f))]
    files_test = [os.path.join(test_dir, f) for f in os.listdir(test_dir) if os.path.isfile(os.path.join(test_dir, f))]
    files_validate = [os.path.join(validate_dir, f) for f in os.listdir(validate_dir) if os.path.isfile(os.path.join(validate_dir, f))]

    print("Loading data")
    x_train = load(files_train, use_multiprocessing)
    x_test = load(files_test, use_multiprocessing)
    x_validate = load(files_validate, use_multiprocessing)

    # scale image data to the [-0.5, 0.5] range
    x_train = x_train/255.0 - 0.5
    x_test = x_test/255.0 - 0.5
    x_validate = x_validate/255.0 - 0.5

    return x_train, x_test, x_validate
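
A usage sketch of the loader (the worker count of 8 is illustrative, not something the PR sets; the paths mirror driver.py):

```python
# Hypothetical call: decode the PNGs with 8 worker processes.
x_train, x_test, x_validate = get_data(
    'data/keras_png_slices_data/keras_png_slices_train',
    'data/keras_png_slices_data/keras_png_slices_test',
    'data/keras_png_slices_data/keras_png_slices_validate',
    use_multiprocessing=8,
)
```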
37 changes: 37 additions & 0 deletions recognition/45819061-VQVAE-OASIS/driver.py
@@ -0,0 +1,37 @@
from matplotlib import pyplot as plt
import tensorflow as tf
import numpy as np
from dataset import *
from modules import *
from predict import *
from train import *

VQVAE_DIR = "vqvae"
PIXELCNN_DIR = "pixelcnn"
LATENT_DIM = 32
NUM_EMBEDDINGS = 128
RESIDUAL_HIDDENS = 32
EPOCHS = 50
BATCH_SIZE = 64
DATA_DIR = 'data/keras_png_slices_data'
TRAIN_DATA = DATA_DIR + '/keras_png_slices_train'
TEST_DATA = DATA_DIR + '/keras_png_slices_test'
VALIDATE_DATA = DATA_DIR + '/keras_png_slices_validate'
#model = tf.keras.models.load_model(VQVAE_DIR, custom_objects={'VectorQuantizer': VectorQuantizer})
#pixelcnn = tf.keras.models.load_model(PIXELCNN_DIR, custom_objects={'PixelCNN': PixelCNN, 'ResidualBlock': ResidualBlock})

x_train, x_test, x_validate = get_data(TRAIN_DATA, TEST_DATA, VALIDATE_DATA)
model = train(x_train, x_test, x_validate,
epochs=EPOCHS, batch_size=BATCH_SIZE, out_dir=VQVAE_DIR,
latent_dim=LATENT_DIM,
num_embeddings=NUM_EMBEDDINGS,
residual_hiddens=RESIDUAL_HIDDENS
)

demo_vqvae(model, x_test)

pixelcnn = pixelcnn_train(model, x_train, x_test, x_validate,
epochs=EPOCHS, batch_size=BATCH_SIZE, out_dir=PIXELCNN_DIR,
num_embeddings=NUM_EMBEDDINGS
)
sample_images(model, pixelcnn)
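
The commented-out `tf.keras.models.load_model` lines near the top of driver.py show how previously trained models can be reloaded (passing the custom layer classes via `custom_objects`) instead of retraining from scratch.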
139 changes: 139 additions & 0 deletions recognition/45819061-VQVAE-OASIS/modules.py
@@ -0,0 +1,139 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Layer, ReLU, Add, Conv2D, Conv2DTranspose
import tensorflow_probability as tfp

class VectorQuantizer(Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, name="VQ", **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta  # weight on the commitment loss term

        # Initialise the flattened embeddings (the codebook), one code per column.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings),
                dtype='float32'
            ),
            trainable=True,
            name=name
        )

    def call(self, x):
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, (-1, self.embedding_dim))

        # Quantization: snap each encoder vector to its nearest codebook entry.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Commitment loss pulls the encoder towards the codebook; codebook loss
        # pulls the codebook towards the encoder outputs.
        commitment_loss = tf.nn.l2_loss(tf.stop_gradient(quantized) - x)**2
        codebook_loss = tf.nn.l2_loss(quantized - tf.stop_gradient(x))**2

        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator: copy gradients from the decoder input
        # back to the encoder output through the non-differentiable lookup.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Squared Euclidean distance expanded as ||x||^2 + ||e||^2 - 2 x.e
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


def resblock(x, filters=256):
    xconv = Conv2D(filters, 3, strides=1, activation='relu', padding='same')(x)
    xconv = Conv2D(x.shape[-1], 1, strides=1, padding='same')(xconv)
    out = Add()([x, xconv])
    return ReLU()(out)

class PixelCNN(Layer):
    def __init__(self, mask_type, **kwargs):
        super(PixelCNN, self).__init__()
        self.mask_type = mask_type
        self.conv = Conv2D(**kwargs)

    def build(self, input_shape):
        self.conv.build(input_shape)
        kernel_shape = self.conv.kernel.get_shape()
        # Mask out "future" pixels: rows above the centre, and columns left of
        # the centre within the centre row, remain visible.
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[:kernel_shape[0]//2, ...] = 1.0
        self.mask[kernel_shape[0]//2, :kernel_shape[1]//2, ...] = 1.0
        # Type 'B' masks (used after the first layer) also see the centre pixel.
        if self.mask_type == 'B':
            self.mask[kernel_shape[0]//2, kernel_shape[1]//2, ...] = 1.0

    def call(self, inputs):
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)

class ResidualBlock(Layer):
    def __init__(self, filters):
        super(ResidualBlock, self).__init__()
        self.conv1 = Conv2D(filters=filters, kernel_size=1, activation='leaky_relu')
        self.pixelcnn = PixelCNN(mask_type='B', filters=filters//2, kernel_size=3, activation='leaky_relu', padding='same')
        self.conv2 = Conv2D(filters=filters, kernel_size=1, activation='leaky_relu')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixelcnn(x)
        x = self.conv2(x)
        return tf.keras.layers.add([inputs, x])

def get_pixelcnn(input_shape, num_embeddings, filters=128, num_residual_blocks=2, num_pixelcnn_layers=2, **kwargs):
    pixelcnn_inputs = Input(shape=input_shape, dtype=tf.int32)
    onehot = tf.one_hot(pixelcnn_inputs, num_embeddings)
    x = PixelCNN(mask_type='A', filters=filters, kernel_size=32, activation='leaky_relu', padding='same')(onehot)
    for _ in range(num_residual_blocks):
        x = ResidualBlock(filters=filters)(x)
    for _ in range(num_pixelcnn_layers):
        x = PixelCNN(mask_type='B', filters=filters, kernel_size=1, strides=1, activation='leaky_relu', padding='valid')(x)
    out = Conv2D(filters=num_embeddings, kernel_size=1, strides=1, padding="valid")(x)
    return tf.keras.Model(pixelcnn_inputs, out, name='pixelcnn')


def get_vqvae(latent_dim=16, num_embeddings=64, input_shape=(256, 256, 1), residual_hiddens=64):
    # Build encoder: two stride-2 convolutions, so the latent grid is 1/4 of
    # the input resolution in each spatial dimension.
    encoder_in = Input(shape=input_shape)
    x = Conv2D(32, 3, strides=2, activation='leaky_relu', padding='same')(encoder_in)
    x = Conv2D(64, 3, strides=2, activation='leaky_relu', padding='same')(x)
    encoder_out = Conv2D(latent_dim, 1, padding="same")(x)
    encoder = tf.keras.Model(encoder_in, encoder_out, name='encoder')

    # Build decoder: mirror of the encoder using transposed convolutions.
    decoder_in = Input(shape=encoder.output.shape[1:])
    y = Conv2DTranspose(64, 3, strides=2, activation='leaky_relu', padding='same')(decoder_in)
    y = Conv2DTranspose(32, 3, strides=2, activation='leaky_relu', padding='same')(y)
    decoder_out = Conv2DTranspose(1, 3, strides=1, activation='leaky_relu', padding='same')(y)
    decoder = tf.keras.Model(decoder_in, decoder_out, name='decoder')

    # Add VQ layer. Note: the name kwarg only names the embeddings variable;
    # the layer itself keeps the auto-generated Keras name 'vector_quantizer',
    # which predict.py relies on when calling get_layer.
    vq_layer = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=latent_dim, name='vq')

    inputs = Input(shape=input_shape)
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return tf.keras.Model(inputs, reconstructions, name='vq-vae')

def get_pixelcnn_sampler(pixelcnn):
    # Wrap the PixelCNN with a categorical head so each forward pass draws a
    # sampled code index per latent position from the predicted logits.
    inputs = Input(shape=pixelcnn.input_shape[1:])
    outputs = pixelcnn(inputs, training=False)
    categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
    outputs = categorical_layer(outputs)
    return tf.keras.Model(inputs, outputs)
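
A quick shape check of the assembled VQ-VAE, as an illustrative sketch only (the numbers mirror driver.py's hyperparameters; the random input stands in for a real scan):

```python
import tensorflow as tf

model = get_vqvae(latent_dim=32, num_embeddings=128, input_shape=(256, 256, 1))
x = tf.random.uniform((1, 256, 256, 1)) - 0.5    # same [-0.5, 0.5] range as dataset.py
print(model(x).shape)                            # (1, 256, 256, 1)
print(model.get_layer('encoder').output.shape)   # (None, 64, 64, 32): the latent grid
```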
82 changes: 82 additions & 0 deletions recognition/45819061-VQVAE-OASIS/predict.py
@@ -0,0 +1,82 @@
from modules import get_pixelcnn_sampler
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

def show_subplot(original, reconstructed, i):
    plt.subplot(1, 2, 1)
    plt.imshow(original.squeeze() + 0.5)
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reconstructed.squeeze() + 0.5)
    plt.title("Reconstructed")
    plt.axis("off")
    plt.savefig('fig' + str(i))
    plt.close()

def demo_vqvae(model, x_test):
    # Reconstruct ten random test images and plot original/reconstruction pairs.
    idx = np.random.choice(len(x_test), 10)
    test_images = x_test[idx]
    reconstructions_test = model.predict(test_images)

    for i, (test_image, reconstructed_image) in enumerate(zip(test_images, reconstructions_test)):
        show_subplot(test_image, reconstructed_image, i)

    encoder = model.get_layer("encoder")
    quantizer = model.get_layer("vector_quantizer")

    # Map each test image to its grid of codebook indices at the bottleneck.
    encoded_outputs = encoder.predict(test_images)
    flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
    codebook_indices = quantizer.get_code_indices(flat_enc_outputs)
    codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])

    for i in range(len(test_images)):
        plt.subplot(1, 2, 1)
        plt.imshow(test_images[i].squeeze() + 0.5)
        plt.title("Original")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(codebook_indices[i])
        plt.title("Code")
        plt.axis("off")
        plt.savefig('embedding' + str(i))
        plt.close()


def sample_images(vqvae, pixelcnn):
    decoder = vqvae.get_layer('decoder')
    quantizer = vqvae.get_layer('vector_quantizer')
    sampler = get_pixelcnn_sampler(pixelcnn)

    prior_batch_size = 10
    priors = np.zeros(shape=(prior_batch_size,) + pixelcnn.input_shape[1:])
    batch, rows, cols = priors.shape

    # Autoregressive sampling: generate one latent position at a time in
    # raster order, feeding the partially completed priors back in each step.
    for row in range(rows):
        for col in range(cols):
            samples = sampler.predict(priors, verbose=0)
            priors[:, row, col] = samples[:, row, col]

    # Look up the embedding vector for each sampled code index ...
    pretrained_embeddings = quantizer.embeddings
    prior_onehot = tf.one_hot(priors.astype("int32"), quantizer.num_embeddings).numpy()
    quantized = tf.matmul(prior_onehot.astype("float32"), pretrained_embeddings, transpose_b=True)
    quantized = tf.reshape(quantized, (-1, *(vqvae.get_layer('encoder').compute_output_shape((1, 256, 256, 1))[1:])))

    # ... and decode the quantized latents into novel images.
    generated_samples = decoder.predict(quantized)

    for i in range(batch):
        plt.subplot(1, 2, 1)
        plt.imshow(priors[i])
        plt.title("Code")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(generated_samples[i].squeeze() + 0.5)
        plt.title("Generated Sample")
        plt.axis("off")
        plt.savefig('gen' + str(i))
        plt.close()