diff --git a/gan/gan.py b/gan/gan.py
index b7033d44f6..ba5b25cf9f 100644
--- a/gan/gan.py
+++ b/gan/gan.py
@@ -5,30 +5,72 @@
 from keras.layers import BatchNormalization, Activation, ZeroPadding2D
 from keras.layers.advanced_activations import LeakyReLU
 from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
+from keras.models import Sequential, Model, model_from_json
 from keras.optimizers import Adam
 
 import matplotlib.pyplot as plt
 
 import sys
+import os
 
 import numpy as np
 
+
 class GAN():
-    def __init__(self):
+
+    def __init__(self, load_model=False):
         self.img_rows = 28
         self.img_cols = 28
         self.channels = 1
         self.img_shape = (self.img_rows, self.img_cols, self.channels)
         self.latent_dim = 100
+        self.model_save_dir = 'saved_model'
+
+        if load_model:
+            self.load_model()
+        else:
+            self.init_model()
+
+    def load_model(self):
+        """Restore the networks saved in `model_save_dir` and re-wire the GAN."""
+        optimizer = Adam(0.0002, 0.5)
+
+        # Load and compile the discriminator
+        self.discriminator = self.load_keras_model("discriminator_model")
+        self.discriminator.compile(loss='binary_crossentropy',
+                                   optimizer=optimizer,
+                                   metrics=['accuracy'])
+
+        # Load the generator
+        self.generator = self.load_keras_model("generator_model")
+
+        # The generator takes noise as input and generates imgs
+        z = Input(shape=(self.latent_dim,))
+        img = self.generator(z)
+
+        # For the combined model we will only train the generator
+        self.discriminator.trainable = False
+
+        # The discriminator takes generated images as input and determines validity
+        validity = self.discriminator(img)
+
+        # The combined model (stacked generator and discriminator)
+        # Trains the generator to fool the discriminator
+        # Built from the loaded sub-models (not from "combined_model") so that
+        # training it updates self.generator instead of a detached copy
+        self.combined = Model(z, validity)
+        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
+
+    def init_model(self):
+        """Build freshly initialised networks and wire up the GAN."""
 
         optimizer = Adam(0.0002, 0.5)
 
         # Build and compile the discriminator
         self.discriminator = self.build_discriminator()
         self.discriminator.compile(loss='binary_crossentropy',
-            optimizer=optimizer,
-            metrics=['accuracy'])
+                                   optimizer=optimizer,
+                                   metrics=['accuracy'])
 
         # Build the generator
         self.generator = self.build_generator()
@@ -43,12 +85,11 @@ def __init__(self):
         # The discriminator takes generated images as input and determines validity
         validity = self.discriminator(img)
 
-        # The combined model  (stacked generator and discriminator)
+        # The combined model (stacked generator and discriminator)
         # Trains the generator to fool the discriminator
         self.combined = Model(z, validity)
         self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
-
     def build_generator(self):
 
         model = Sequential()
@@ -89,7 +130,7 @@ def build_discriminator(self):
 
         return Model(img, validity)
 
-    def train(self, epochs, batch_size=128, sample_interval=50):
+    def train(self, epochs, batch_size=128, sample_interval=50, model_save_interval=50):
 
         # Load the dataset
         (X_train, _), (_, _) = mnist.load_data()
@@ -132,12 +173,16 @@ def train(self, epochs, batch_size=128, sample_interval=50):
             g_loss = self.combined.train_on_batch(noise, valid)
 
             # Plot the progress
-            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
+            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
+                  (epoch, d_loss[0], 100 * d_loss[1], g_loss))
 
             # If at save interval => save generated image samples
             if epoch % sample_interval == 0:
                 self.sample_images(epoch)
 
+            if epoch != 0 and epoch % model_save_interval == 0:
+                self.save_models()
+
     def sample_images(self, epoch):
         r, c = 5, 5
         noise = np.random.normal(0, 1, (r * c, self.latent_dim))
@@ -150,13 +195,46 @@ def sample_images(self, epoch):
         cnt = 0
         for i in range(r):
             for j in range(c):
-                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
-                axs[i,j].axis('off')
+                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
+                axs[i, j].axis('off')
                 cnt += 1
         fig.savefig("images/%d.png" % epoch)
         plt.close()
 
+    def load_keras_model(self, model_name):
+        """Return the model stored as <model_name>.json/.h5 in `model_save_dir`."""
+        json_name = os.path.join(self.model_save_dir, model_name + ".json")
+        weights_name = os.path.join(self.model_save_dir, model_name + ".h5")
+
+        with open(json_name, 'r') as json_file:
+            loaded_model_json = json_file.read()
+
+        loaded_model = model_from_json(loaded_model_json)
+        loaded_model.load_weights(weights_name)
+
+        return loaded_model
+
+    def save_models(self):
+        """Write all three models to `model_save_dir`, creating it if needed."""
+        if not os.path.isdir(self.model_save_dir):
+            os.makedirs(self.model_save_dir)
+
+        self.save_model(self.discriminator,
+                        os.path.join(self.model_save_dir, 'discriminator_model'))
+        self.save_model(self.generator,
+                        os.path.join(self.model_save_dir, 'generator_model'))
+        self.save_model(self.combined,
+                        os.path.join(self.model_save_dir, 'combined_model'))
+
+    def save_model(self, model, model_path):
+        """Save `model` as <model_path>.json (architecture) and .h5 (weights)."""
+        with open(str(model_path) + '.json', 'w') as json_file:
+            json_file.write(model.to_json())
+
+        model.save_weights(str(model_path + '.h5'))
+
 
 if __name__ == '__main__':
-    gan = GAN()
-    gan.train(epochs=30000, batch_size=32, sample_interval=200)
+    gan = GAN(load_model=False)
+    gan.train(epochs=30001, batch_size=32,
+              sample_interval=200, model_save_interval=200)