Skip to content
Snippets Groups Projects
Commit 88b4a56c authored by JGlombitza's avatar JGlombitza
Browse files

change data path

parent 76756103
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,6 @@ from keras.layers.advanced_activations import LeakyReLU ...@@ -8,7 +8,6 @@ from keras.layers.advanced_activations import LeakyReLU
from functools import partial from functools import partial
from keras.utils import plot_model from keras.utils import plot_model
import keras.backend.tensorflow_backend as KTF import keras.backend.tensorflow_backend as KTF
import glob
import utils import utils
import tensorflow as tf import tensorflow as tf
...@@ -16,22 +15,19 @@ import tensorflow as tf ...@@ -16,22 +15,19 @@ import tensorflow as tf
KTF.set_session(utils.get_session()) # Allows 2 jobs per GPU, Please do not change this during the tutorial KTF.set_session(utils.get_session()) # Allows 2 jobs per GPU, Please do not change this during the tutorial
log_dir="." log_dir="."
EPOCHS = 3 EPOCHS = 10
GRADIENT_PENALTY_WEIGHT = 10 GRADIENT_PENALTY_WEIGHT = 10
BATCH_SIZE = 256 BATCH_SIZE = 256
NCR = 5 NCR = 5
latent_size = 512 latent_size = 512
# load trainings data
filenames=glob.glob("*.npz")
shower_maps, Energy = utils.ReadInData(filenames) # load trainings data
#shower_maps = shower_maps[:,:,:,1,np.newaxis] shower_maps, Energy = utils.ReadInData()
#np.savez("Data", Energy=Energy, shower_maps=shower_maps) N = shower_maps.shape[0]
Energy = Energy/np.max(Energy) # plot real signal patterns
utils.plot_multiple_signalmaps(shower_maps[:,:,:,0], log_dir=log_dir, title='Footprints', epoch='Real') utils.plot_multiple_signalmaps(shower_maps[:,:,:,0], log_dir=log_dir, title='Footprints', epoch='Real')
N = shower_maps.shape[0]
class RandomWeightedAverage(_Merge): class RandomWeightedAverage(_Merge):
"""Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
...@@ -116,6 +112,7 @@ critic_loss = [] ...@@ -116,6 +112,7 @@ critic_loss = []
iterations_per_epoch = N//((NCR+1)*BATCH_SIZE) iterations_per_epoch = N//((NCR+1)*BATCH_SIZE)
for epoch in range(EPOCHS): for epoch in range(EPOCHS):
print "epoch: ", epoch print "epoch: ", epoch
# plot berfore each epoch generated signal patterns
generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size)) generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch)) utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints Epoch: ', epoch=str(epoch))
for iteration in range(iterations_per_epoch): for iteration in range(iterations_per_epoch):
...@@ -131,6 +128,6 @@ for epoch in range(EPOCHS): ...@@ -131,6 +128,6 @@ for epoch in range(EPOCHS):
utils.plot_loss(critic_loss, name="critic", log_dir=log_dir) utils.plot_loss(critic_loss, name="critic", log_dir=log_dir)
utils.plot_loss(generator_loss, name="generator",log_dir=log_dir) utils.plot_loss(generator_loss, name="generator",log_dir=log_dir)
# plot some generated figures # plot some generated signal patterns
generated_map = generator.predict(np.random.randn(BATCH_SIZE, latent_size)) generated_map = generator.predict(np.random.randn(BATCH_SIZE, latent_size))
utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints', epoch='Final') utils.plot_multiple_signalmaps(generated_map[:,:,:,0], log_dir=log_dir, title='Generated Footprints', epoch='Final')
...@@ -15,17 +15,11 @@ def get_session(gpu_fraction=0.40): ...@@ -15,17 +15,11 @@ def get_session(gpu_fraction=0.40):
return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
def ReadInData(filenames): def ReadInData():
'''Reads in the trainings data''' '''Reads in the trainings data'''
N = 100000 *len(filenames) filenames="/net/scratch/JGlombitza/Public/HAPWorkshop2018/data/Data.npz"
a = 100000 data = np.load(filenames)
shower_maps = np.zeros(N*9*9*1).reshape(N,9,9,1) return data['shower_maps'], data['Energy']
Energy = np.zeros(N)
for i in range(0, len(filenames)):
data = np.load(filenames[i])
Energy[a*i:a*(i+1)] = data['Energy']
shower_maps[a*i:a*(i+1)] = data['shower_maps'].reshape(a,9,9,1)
return shower_maps, Energy
def make_trainable(model, trainable): def make_trainable(model, trainable):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment