Commit 65484081 authored by JGlombitza

plot_model

remove class

parent 1ce9c454
@@ -28,18 +28,6 @@ N = shower_maps.shape[0]
utils.plot_multiple_signalmaps(shower_maps[:,:,:,0], log_dir=log_dir, title='Footprints', epoch='Real')
-class RandomWeightedAverage(_Merge):
-    """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
-    between each pair of input points.
-    Inheriting from _Merge is a little messy but it was the quickest solution I could think of.
-    Improvements appreciated."""
-
-    def _merge_function(self, inputs):
-        weights = K.random_uniform((BATCH_SIZE, 1, 1, 1))
-        return (weights * inputs[0]) + ((1 - weights) * inputs[1])
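
The commit message says this class is removed, so the interpolation presumably lives elsewhere now (e.g. inside utils.build_critic_graph). For reference, the same randomly weighted average can be written without subclassing _Merge, as a plain Keras Lambda layer. This is a sketch of that approach, not necessarily the replacement actually used; real_images and fake_images are placeholder tensor names:

import keras.backend as K
from keras.layers import Lambda

def random_weighted_average(inputs):
    # eps ~ U(0, 1), drawn once per sample and broadcast over height,
    # width and channels; the WGAN-GP gradient penalty is evaluated at
    # these random points on the line between each real/fake pair
    real, fake = inputs
    eps = K.random_uniform((BATCH_SIZE, 1, 1, 1))
    return eps * real + (1.0 - eps) * fake

# averaged_batch = Lambda(random_weighted_average)([real_images, fake_images])
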
# build generator
def build_generator(latent_size):
inputs = Input(shape=(latent_size,))
@@ -77,9 +65,9 @@ def build_critic():
generator = build_generator(latent_size)
#plot_model(generator, to_file=log_dir + '/generator.png', show_shapes=True)
print(generator.summary())
critic = build_critic()
#plot_model(critic, to_file=log_dir + '/critic.png', show_shapes=True)
print(critic.summary())
# make training model for generator
utils.make_trainable(critic, False)
@@ -87,7 +75,7 @@ utils.make_trainable(generator, True)
generator_training = utils.build_generator_graph(generator, critic, latent_size)
generator_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss])
-#plot_model(generator_training, to_file=log_dir + '/generator_training.png', show_shapes=True)
+plot_model(generator_training, to_file=log_dir + '/generator_training.png', show_shapes=True)
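
utils.wasserstein_loss itself is not part of this diff; in Keras WGAN implementations it is commonly the label-weighted mean critic score, as sketched below (an assumption about this repo's helper, not a quote from it):

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    # y_true is +1 for real batches and -1 for generated batches, so
    # minimizing this mean pushes the critic scores apart for the two
    return K.mean(y_true * y_pred)
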
# make training model for critic
utils.make_trainable(critic, True)
@@ -98,7 +86,7 @@ critic_training, averaged_batch = utils.build_critic_graph(generator, critic, latent_size)
gradient_penalty = partial(utils.gradient_penalty_loss, averaged_batch=averaged_batch, penalty_weight=GRADIENT_PENALTY_WEIGHT)
gradient_penalty.__name__ = 'gradient_penalty'
critic_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss, utils.wasserstein_loss, gradient_penalty])
-#plot_model(critic_training, to_file=log_dir + '/critic_training.png', show_shapes=True)
+plot_model(critic_training, to_file=log_dir + '/critic_training.png', show_shapes=True)
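
Likewise, utils.gradient_penalty_loss is defined outside this diff. Below is a minimal sketch of the standard WGAN-GP penalty it presumably implements, evaluated at the interpolated averaged_batch returned by build_critic_graph; the actual helper may differ:

import keras.backend as K

def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight):
    # gradient of the critic output with respect to the interpolated samples
    gradients = K.gradients(y_pred, averaged_batch)[0]
    # per-sample L2 norm, summed over all non-batch axes
    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients),
                                    axis=list(range(1, K.ndim(gradients)))))
    # penalize deviations of the gradient norm from 1
    return penalty_weight * K.mean(K.square(gradient_l2_norm - 1))
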
# For Wasserstein loss
positive_y = np.ones(BATCH_SIZE)
......
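
The rest of the file is truncated here. For orientation only: the usual Keras WGAN-GP training loop pairs positive_y with a negated copy and a dummy vector, one label array per critic output. The names below are illustrative, not taken from this repository:

negative_y = -positive_y           # targets for generated samples
dummy_y = np.zeros(BATCH_SIZE)     # placeholder for the gradient-penalty output

# critic_training.train_on_batch([noise_batch, real_batch],
#                                [positive_y, negative_y, dummy_y])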