Commit 359dd7c4 authored by JGlombitza

add model to MNIST.py

parent f1b2b366
MNIST.py

import numpy as np
from keras.layers import Input, Activation, BatchNormalization, Conv2D, Dense, Dropout, Flatten, Reshape, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.regularizers import l1_l2
from keras.models import Sequential, Model
from keras.optimizers import Adam
import mnist
import matplotlib.pyplot as plt
import keras.backend.tensorflow_backend as KTF
from gan import plot_images, make_trainable, get_session

log_dir = "."
@@ -20,27 +22,50 @@ X_test = data.test_images.reshape(-1, 28, 28, 1) / 255.
idx = np.random.choice(len(X_train), 16)
plot_images(X_train[idx], fname=log_dir + '/real_images.png')
# --------------------------------------------------
# Set up generator, discriminator and GAN (stacked generator + discriminator)
# Feel free to modify, e.g.:
# - the provided models (see gan.py)
# - the learning rate
# - the batch size
# --------------------------------------------------
# Set up generator
print('\nGenerator')
latent_dim = 100
generator = Sequential(name='generator')
generator.add(Dense(7 * 7 * 128, input_shape=(latent_dim,)))  # project latent vector to a 7x7x128 feature map
generator.add(BatchNormalization())
generator.add(Activation('relu'))
generator.add(Reshape([7, 7, 128]))
generator.add(UpSampling2D(size=(2, 2)))  # 7x7 -> 14x14
generator.add(Conv2D(128, (5, 5), padding='same'))
generator.add(BatchNormalization())
generator.add(Activation('relu'))
generator.add(UpSampling2D(size=(2, 2)))  # 14x14 -> 28x28
generator.add(Conv2D(64, (5, 5), padding='same'))
generator.add(BatchNormalization())
generator.add(Activation('relu'))
generator.add(Conv2D(1, (5, 5), padding='same', activation='sigmoid'))  # sigmoid matches images scaled to [0, 1]
print(generator.summary())
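# Illustrative sketch, not part of this commit: the (still untrained)
# generator can already be sampled and inspected with the plot_images helper.
# The standard-normal latent prior and the file name are assumptions.
z = np.random.randn(16, latent_dim)
generated = generator.predict(z)
plot_images(generated, fname=log_dir + '/untrained_generated.png')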
# Set up discriminator
print('\nDiscriminator')
drop_rate = 0.25
discriminator = Sequential(name='discriminator')
# note: no activation inside the Conv2D layers -- a ReLU applied before
# LeakyReLU would clip all negative values and make the LeakyReLU a no-op
discriminator.add(Conv2D(32, (5, 5), padding='same', strides=(2, 2), input_shape=(28, 28, 1)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(drop_rate))
discriminator.add(Conv2D(64, (5, 5), padding='same', strides=(2, 2)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(drop_rate))
discriminator.add(Conv2D(128, (5, 5), padding='same', strides=(2, 2)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(drop_rate))
discriminator.add(Flatten())
discriminator.add(Dense(256, activity_regularizer=l1_l2(1e-5)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(drop_rate))
discriminator.add(Dense(2, activation='softmax'))  # two-class softmax output (fake vs. real)
print(discriminator.summary())

d_opt = Adam(lr=2e-4, beta_1=0.5, decay=0.0005)
discriminator.compile(loss='binary_crossentropy', optimizer=d_opt, metrics=['accuracy'])
# Set up GAN by stacking the discriminator on top of the generator
print('\nGenerative Adversarial Network')
gan_input = Input(shape=[latent_dim])
@@ -52,10 +77,8 @@ make_trainable(discriminator, False)  # freezes the discriminator when training
GAN.compile(loss='binary_crossentropy', optimizer=g_opt)
# Compiling stores the trainable status in the model --> after compilation, calling make_trainable has no effect until the model is compiled again
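# Illustrative sketch, not part of this commit: because the discriminator was
# compiled above while still trainable, discriminator.train_on_batch updates
# its weights, whereas GAN.train_on_batch (compiled with the discriminator
# frozen) updates only the generator. The batch size and the one-hot label
# layout for the 2-unit softmax are assumptions.
z = np.random.randn(16, latent_dim)
fake = generator.predict(z)
y_fake = np.zeros((16, 2)); y_fake[:, 0] = 1   # assumed: column 0 = fake
y_real = np.zeros((16, 2)); y_real[:, 1] = 1   # assumed: column 1 = real
discriminator.train_on_batch(fake, y_fake)     # updates discriminator weights
GAN.train_on_batch(z, y_real)                  # updates generator weights only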
# --------------------------------------------------
# Pretrain the discriminator:
# --------------------------------------------------
# - Create a dataset of 10000 real train images and 10000 fake images.
ntrain = 10000
no = np.random.choice(60000, size=ntrain, replace=False)  # boolean False: the string 'False' is truthy and would sample with replacement
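# Illustrative sketch, not part of this commit (the actual continuation of the
# script is truncated below): one plausible way to finish the pretraining step
# is to mix the selected real images with generated ones and fit the
# discriminator on one-hot labels; the label layout, epoch count and batch
# size here are assumptions.
z = np.random.randn(ntrain, latent_dim)
X_pre = np.concatenate([X_train[no], generator.predict(z)])
y_pre = np.zeros((2 * ntrain, 2))
y_pre[:ntrain, 1] = 1   # assumed: column 1 = real
y_pre[ntrain:, 0] = 1   # assumed: column 0 = fake
discriminator.fit(X_pre, y_pre, epochs=1, batch_size=128)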
...

gan.py

import numpy as np
import matplotlib.pyplot as plt
# removed in this commit (moved into MNIST.py):
# from keras.layers import Activation, BatchNormalization, Conv2D, Dense, Dropout, Flatten, Reshape, UpSampling2D
# from keras.layers.advanced_activations import LeakyReLU
# from keras.regularizers import l1_l2
# from keras.models import Sequential
import tensorflow as tf
@@ -34,41 +30,3 @@ def make_trainable(model, trainable):
    model.trainable = trainable
    for l in model.layers:
        l.trainable = trainable
# removed in this commit (the model definitions now live inline in MNIST.py):
def build_generator(latent_dim):
    generator = Sequential(name='generator')
    generator.add(Dense(7 * 7 * 128, input_shape=(latent_dim,)))
    generator.add(BatchNormalization())
    generator.add(Activation('relu'))
    generator.add(Reshape([7, 7, 128]))
    generator.add(UpSampling2D(size=(2, 2)))
    generator.add(Conv2D(128, (5, 5), padding='same'))
    generator.add(BatchNormalization())
    generator.add(Activation('relu'))
    generator.add(UpSampling2D(size=(2, 2)))
    generator.add(Conv2D(64, (5, 5), padding='same'))
    generator.add(BatchNormalization())
    generator.add(Activation('relu'))
    generator.add(Conv2D(1, (5, 5), padding='same', activation='sigmoid'))
    return generator

def build_discriminator(drop_rate=0.25):
    """ Discriminator network """
    discriminator = Sequential(name='discriminator')
    discriminator.add(Conv2D(32, (5, 5), padding='same', strides=(2, 2), activation='relu', input_shape=(28, 28, 1)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(drop_rate))
    discriminator.add(Conv2D(64, (5, 5), padding='same', strides=(2, 2), activation='relu'))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(drop_rate))
    discriminator.add(Conv2D(128, (5, 5), padding='same', strides=(2, 2), activation='relu'))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(drop_rate))
    discriminator.add(Flatten())
    discriminator.add(Dense(256, activity_regularizer=l1_l2(1e-5)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(drop_rate))
    discriminator.add(Dense(2, activation='softmax'))
    return discriminator