Skip to content
Snippets Groups Projects
Commit 6fb856b5 authored by JGlombitza's avatar JGlombitza
Browse files

add new explanations

parent c13ebed6
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,7 @@ import tensorflow as tf ...@@ -14,6 +14,7 @@ import tensorflow as tf
KTF.set_session(utils.get_session())  # Allows 2 jobs per GPU, Please do not change this during the tutorial

# directory where plots and model graphs are written
log_dir = "."

# basic trainings parameter
EPOCHS = 10                   # number of passes over the training set
GRADIENT_PENALTY_WEIGHT = 10  # lambda of the WGAN-GP gradient-penalty term
BATCH_SIZE = 256              # samples per critic/generator update
...@@ -45,6 +46,7 @@ def build_generator(latent_size): ...@@ -45,6 +46,7 @@ def build_generator(latent_size):
generator.add(Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')) generator.add(Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu'))
return generator return generator
# build critic # build critic
# Feel free to modify the critic model # Feel free to modify the critic model
def build_critic(): def build_critic():
...@@ -70,20 +72,19 @@ critic = build_critic() ...@@ -70,20 +72,19 @@ critic = build_critic()
print(critic.summary())

# make trainings model for generator
utils.make_trainable(critic, False)    # freeze the critic during the generator training
utils.make_trainable(generator, True)  # unfreeze the generator during the generator training
generator_training = utils.build_generator_graph(generator, critic, latent_size)
generator_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss])
plot_model(generator_training, to_file=log_dir + '/generator_training.png', show_shapes=True)

# make trainings model for critic
utils.make_trainable(critic, True)      # unfreeze the critic during the critic training
utils.make_trainable(generator, False)  # freeze the generator during the critic training
critic_training, averaged_batch = utils.build_critic_graph(generator, critic, latent_size, batch_size=BATCH_SIZE)
# construct the gradient penalty; partial binds the extra args so Keras can call it as loss(y_true, y_pred)
gradient_penalty = partial(utils.gradient_penalty_loss, averaged_batch=averaged_batch, penalty_weight=GRADIENT_PENALTY_WEIGHT)
gradient_penalty.__name__ = 'gradient_penalty'  # Keras requires a __name__ on loss functions; partial has none
critic_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss, utils.wasserstein_loss, gradient_penalty])
plot_model(critic_training, to_file=log_dir + '/critic_training.png', show_shapes=True)
...@@ -104,15 +105,19 @@ for epoch in range(EPOCHS): ...@@ -104,15 +105,19 @@ for epoch in range(EPOCHS):
generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size)) generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch)) utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch))
for iteration in range(iterations_per_epoch): for iteration in range(iterations_per_epoch):
for j in range(NCR): for j in range(NCR):
noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator
shower_batch = shower_maps[BATCH_SIZE*(j+iteration):BATCH_SIZE*(j++iteration+1)] shower_batch = shower_maps[BATCH_SIZE*(j+iteration):BATCH_SIZE*(j++iteration+1)] # take batch of shower maps
critic_loss.append(critic_training.train_on_batch([noise_batch, shower_batch], [negative_y, positive_y, dummy])) critic_loss.append(critic_training.train_on_batch([noise_batch, shower_batch], [negative_y, positive_y, dummy])) # train the critic
print "critic loss:", critic_loss[-1] print "critic loss:", critic_loss[-1]
noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator
generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y])) generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y])) # train the generator
print "generator loss:", generator_loss[-1] print "generator loss:", generator_loss[-1]
# plot critic and generator loss
utils.plot_loss(critic_loss, name="critic", log_dir=log_dir) utils.plot_loss(critic_loss, name="critic", log_dir=log_dir)
utils.plot_loss(generator_loss, name="generator",log_dir=log_dir) utils.plot_loss(generator_loss, name="generator",log_dir=log_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment