Skip to content
Snippets Groups Projects
Commit 974e6029 authored by Lukas Geiger's avatar Lukas Geiger
Browse files

Fix linting errors

parent 1c7ca8e6
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from tensorflow import keras
from keras.layers import * from keras.layers import *
from keras.models import Sequential, Model from keras.models import Sequential
from keras.optimizers import Adam from keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU from keras.layers.advanced_activations import LeakyReLU
from functools import partial from functools import partial
from keras.utils import plot_model from keras.utils import plot_model
import keras.backend.tensorflow_backend as KTF import keras.backend.tensorflow_backend as KTF
import utils import utils
import tensorflow as tf
KTF.set_session(utils.get_session()) # Allows 2 jobs per GPU, Please do not change this during the tutorial KTF.set_session(utils.get_session()) # Allows 2 jobs per GPU, Please do not change this during the tutorial
...@@ -100,7 +98,7 @@ critic_loss = [] ...@@ -100,7 +98,7 @@ critic_loss = []
# trainings loop # trainings loop
iterations_per_epoch = N//((NCR+1)*BATCH_SIZE) iterations_per_epoch = N//((NCR+1)*BATCH_SIZE)
for epoch in range(EPOCHS): for epoch in range(EPOCHS):
print "epoch: ", epoch print("epoch: ", epoch)
# plot berfore each epoch generated signal patterns # plot berfore each epoch generated signal patterns
generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size)) generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
utils.plot_multiple_signalmaps(generated_map[:, :, :, 0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch)) utils.plot_multiple_signalmaps(generated_map[:, :, :, 0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch))
...@@ -111,11 +109,11 @@ for epoch in range(EPOCHS): ...@@ -111,11 +109,11 @@ for epoch in range(EPOCHS):
noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator
shower_batch = shower_maps[BATCH_SIZE*(j+iteration):BATCH_SIZE*(j++iteration+1)] # take batch of shower maps shower_batch = shower_maps[BATCH_SIZE*(j+iteration):BATCH_SIZE*(j++iteration+1)] # take batch of shower maps
critic_loss.append(critic_training.train_on_batch([noise_batch, shower_batch], [negative_y, positive_y, dummy])) # train the critic critic_loss.append(critic_training.train_on_batch([noise_batch, shower_batch], [negative_y, positive_y, dummy])) # train the critic
print "critic loss:", critic_loss[-1] print("critic loss:", critic_loss[-1])
noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator noise_batch = np.random.randn(BATCH_SIZE, latent_size) # generate noise batch for generator
generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y])) # train the generator generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y])) # train the generator
print "generator loss:", generator_loss[-1] print("generator loss:", generator_loss[-1])
# plot critic and generator loss # plot critic and generator loss
utils.plot_loss(critic_loss, name="critic", log_dir=log_dir) utils.plot_loss(critic_loss, name="critic", log_dir=log_dir)
......
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import keras
from keras.layers import * from keras.layers import *
from keras.models import Sequential, Model from keras.models import Model
from keras.layers.merge import _Merge from keras.layers.merge import _Merge
from keras import backend as K
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -111,7 +111,7 @@ def plot_signal_map(footprint, axis, label=None): ...@@ -111,7 +111,7 @@ def plot_signal_map(footprint, axis, label=None):
cbar = plt.colorbar(circles, ax=axis) cbar = plt.colorbar(circles, ax=axis)
cbar.set_label('signal [a.u.]') cbar.set_label('signal [a.u.]')
axis.grid(True) axis.grid(True)
if label!=None: if label is not None:
axis.text(0.95, 0.1, "Energy: %.1f EeV" % label, verticalalignment='top', horizontalalignment='right', transform=axis.transAxes, backgroundcolor='w') axis.text(0.95, 0.1, "Energy: %.1f EeV" % label, verticalalignment='top', horizontalalignment='right', transform=axis.transAxes, backgroundcolor='w')
axis.set_aspect('equal') axis.set_aspect('equal')
axis.set_xlim(-5, 5) axis.set_xlim(-5, 5)
......
...@@ -89,9 +89,9 @@ print('Test accuracy: %04f' % pretrain_acc) ...@@ -89,9 +89,9 @@ print('Test accuracy: %04f' % pretrain_acc)
losses = {"d": [], "g": []} losses = {"d": [], "g": []}
discriminator_acc = [] discriminator_acc = []
# main training loop # main training loop
def train_for_n(epochs=1, batch_size=32): def train_for_n(epochs=1, batch_size=32):
for epoch in range(epochs): for epoch in range(epochs):
# Plot some fake images # Plot some fake images
...@@ -127,6 +127,7 @@ def train_for_n(epochs=1, batch_size=32): ...@@ -127,6 +127,7 @@ def train_for_n(epochs=1, batch_size=32):
g_loss = GAN.train_on_batch(noise_tr, y2) g_loss = GAN.train_on_batch(noise_tr, y2)
losses["g"].append(g_loss) losses["g"].append(g_loss)
train_for_n(epochs=10, batch_size=128) train_for_n(epochs=10, batch_size=128)
# - Plot the loss of discriminator and generator as function of iterations # - Plot the loss of discriminator and generator as function of iterations
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import gzip import gzip
import os import os
import json import json
import matplotlib.pyplot as plt
def get_datapath(fname=''): def get_datapath(fname=''):
"""Get data path. """Get data path.
...@@ -36,15 +36,6 @@ class Dataset(): ...@@ -36,15 +36,6 @@ class Dataset():
""" """
self.__dict__.update(kwds) self.__dict__.update(kwds)
def plot_examples(self, num_examples=5, fname=None):
"""Plot examples from the dataset
Args:
num_examples (int, optional): number of examples to
fname (str, optional): filename for saving the plot
"""
plot_examples(self, num_examples, fname)
def load_data(): def load_data():
"""Load the MNIST dataset. """Load the MNIST dataset.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment