Added functionality to save/load Keras model for intermittent training. #117

Open · wants to merge 1 commit into master
92 changes: 80 additions & 12 deletions gan/gan.py
@@ -5,30 +5,68 @@
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.models import Sequential, Model, model_from_json
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os

import numpy as np


class GAN():
def __init__(self):

def __init__(self, load_model=False):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100

self.model_save_dir = 'saved_model'

if load_model:
self.load_model()
else:
self.init_model()

def load_model(self):
optimizer = Adam(0.0002, 0.5)

# Load and compile the discriminator
self.discriminator = self.load_keras_model("discriminator_model")
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

# Load the generator
self.generator = self.load_keras_model("generator_model")

# The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)

# For the combined model we will only train the generator
self.discriminator.trainable = False

# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)

# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = self.load_keras_model("combined_model")
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

def init_model(self):
optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

# Build the generator
self.generator = self.build_generator()
@@ -43,12 +81,11 @@ def __init__(self):
# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)

# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


def build_generator(self):

model = Sequential()
@@ -89,7 +126,7 @@ def build_discriminator(self):

return Model(img, validity)

def train(self, epochs, batch_size=128, sample_interval=50, model_save_interval=50):

# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
@@ -132,12 +169,16 @@ def train(self, epochs, batch_size=128, sample_interval=50):
g_loss = self.combined.train_on_batch(noise, valid)

# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
(epoch, d_loss[0], 100 * d_loss[1], g_loss))

# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)

if epoch != 0 and epoch % model_save_interval == 0:
self.save_models()

def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
@@ -150,13 +191,40 @@ def sample_images(self, epoch):
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()

def load_keras_model(self, model_name):
json_name = os.path.join(self.model_save_dir, model_name + ".json")
weights_name = os.path.join(self.model_save_dir, model_name + ".h5")

with open(json_name, 'r') as json_file:
loaded_model_json = json_file.read()

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(weights_name)

return loaded_model

def save_models(self):
self.save_model(self.discriminator,
os.path.join(self.model_save_dir, 'discriminator_model'))
self.save_model(self.generator,
os.path.join(self.model_save_dir, 'generator_model'))
self.save_model(self.combined,
os.path.join(self.model_save_dir, 'combined_model'))

def save_model(self, model, model_path):
with open(str(model_path) + '.json', 'w') as json_file:
json_file.write(model.to_json())

model.save_weights(str(model_path) + '.h5')


if __name__ == '__main__':
gan = GAN(load_model=False)
gan.train(epochs=30001, batch_size=32,
          sample_interval=200, model_save_interval=200)
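
For context, here is a minimal sketch of how the new save/load flow might be exercised across two sessions. It is not part of the diff: the import assumes the `GAN` class above is importable (e.g. running from inside the `gan/` directory), and the epoch counts and intervals are placeholders.

```python
# Hypothetical two-session workflow (illustration only, not part of this PR).
# Assumes the GAN class from gan.py is importable and that the 'saved_model/'
# and 'images/' directories already exist.
from gan import GAN

# Session 1: train from scratch; train() writes checkpoints to saved_model/
# every `model_save_interval` epochs.
gan = GAN(load_model=False)
gan.train(epochs=10000, batch_size=32,
          sample_interval=200, model_save_interval=200)

# Session 2 (later): rebuild the generator, discriminator, and combined model
# from saved_model/*.json and *.h5, then continue training.
gan = GAN(load_model=True)
gan.train(epochs=10000, batch_size=32,
          sample_interval=200, model_save_interval=200)
```

Note that `to_json()` and `save_weights()` persist only the architecture and weights, so the Adam optimizer state is not carried over; each resumed session starts with freshly initialized optimizer moments.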