ShowerGAN_signal+time.py

    import numpy as np
    from tensorflow import keras
    from keras import backend as K
    from keras.layers import *
    from keras.layers.merge import _Merge
    from keras.models import Sequential, Model
    from keras.optimizers import Adam
    from keras.layers.advanced_activations import LeakyReLU
    from functools import partial
    from keras.utils import plot_model
    import glob
    import utils
    import tensorflow as tf
    import kerasAppendix as ka
    
    log_dir = ka.CookbookInit()
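    # log_dir is used throughout as the output directory for plots and model graphs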
    
    EPOCHS = 20
    GRADIENT_PENALTY_WEIGHT = 10  # weight of the WGAN-GP gradient penalty term
    BATCH_SIZE = 64
    NCR = 5  # number of critic updates per generator update
    latent_size = 512  # length of the generator's latent (noise) vector
    # load training data
    Folder = '/home/JGlombitza/DataLink/ToyMC/PreProData'  # training data (no black tanks)
    filenames = glob.glob(Folder + "/PlanarWaveShower_PrePro_2*")
    
    shower_maps, Energy = utils.ReadInData(filenames)
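    # the real shower maps are (9, 9, 2) arrays: arrival time (channel 0) and signal
    # (channel 1) per station of a 9x9 grid, matching the critic input shape and the
    # commented-out scatter plots below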
    utils.plot_showermaps(shower_maps[0,:,:,:], epoch='Real', log_dir=log_dir)
    utils.plot_correlation(shower_maps[:2500], log_dir=log_dir, name='Real')
    #utils.plot_array_scatter(shower_maps[0,:,:,0], label = 'time [a.u.]', fname='real_time', log_dir=log_dir)
    #utils.plot_array_scatter(shower_maps[0,:,:,1], label = 'signal [a.u.]', fname='real_signal', log_dir=log_dir)
    
    
    #shower_maps = shower_maps[:,:,:,1,np.newaxis]
    
    N = shower_maps.shape[0]
    
    class RandomWeightedAverage(_Merge):
        """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
        between each pair of input points.
        Inheriting from _Merge is a little messy but it was the quickest solution I could think of.
        Improvements appreciated."""
    
        def _merge_function(self, inputs):
            weights = K.random_uniform((BATCH_SIZE, 1, 1, 1))
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])
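
    # These interpolated samples are the points at which the WGAN-GP gradient penalty
    # (Gulrajani et al., 2017) is evaluated; see the utils.gradient_penalty_loss partial below.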
    
    #def split_layer(value):
    #    return Lambda(lambda y : value[:,:,:,64:])(value), Lambda(lambda x : value[:,:,:,:64])(value)
    
    #def split_layer(value, num, axis=-1):
    #    return tf.split(value, num, axis=axis)
    
    #def split_layer_output_shape(input_shape):
    #    print input_shape
    #    shape = list(input_shape)
    #    print input_shape
    #    shape[-1] = shape[-1]//2
    #    shape[-1] *= 2
    #    return tuple((shape, shape))
    
    def ResnetBlock(inp, nfilters, kernel=(3, 3)):
        """Residual block: a 1x1 convolution matches the channel count, followed by two
        convolutions with batch normalization and a skip connection around them."""
        inp = Conv2D(nfilters, (1, 1), padding='same', kernel_initializer='he_normal', activation='relu')(inp)
        conv = Conv2D(nfilters, kernel, padding='same', kernel_initializer='he_normal', activation='relu')(inp)
        conv = BatchNormalization()(conv)
        conv = Conv2D(nfilters, kernel, padding='same', kernel_initializer='he_normal', activation='relu')(conv)
        conv = BatchNormalization()(conv)
        z = Add()([conv, inp])
        return Activation('relu')(z)
    
    def output_of_lambda(input_shape):
        # the Lambda layer below keeps the shape of its first input
        return input_shape[0]

    def multi(inputs):
        """Masks the time map with the signal map: casting the signal to bool and back
        yields a 0/1 mask, so stations without signal get zero arrival time."""
        x, y = inputs
        y = tf.cast(y, tf.bool)
        y = tf.cast(y, tf.float32)
        return Multiply()([x, y])
    
    # build generator
    def build_generator(latent_size):
        """Two-tower generator: a shared trunk is upsampled from the latent vector to a
        9x9 grid, then separate towers of residual blocks produce the time and the
        signal channel of the footprint."""
        inputs = Input(shape=(latent_size,))
        x = Dense(latent_size, activation='relu')(inputs)
        x = Reshape((1,1,latent_size))(x)
        x = UpSampling2D(size=(3,3))(x)
        x = Conv2D(128, (2, 2), padding='same', kernel_initializer='he_normal', activation='relu')(x)
        x = BatchNormalization()(x)
        x = UpSampling2D(size=(3,3))(x)
        x = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(x)
        x = BatchNormalization()(x)
        x = Conv2D(8, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(x)
        x = BatchNormalization()(x)
    
        signal = ResnetBlock(x, 64)
        signal = ResnetBlock(signal, 32)
    #    signal = ResnetBlock(signal, 32)
    #    signal = ResnetBlock(signal, 32)
        signal = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(signal)
    
        time = ResnetBlock(x, 64)
        time = ResnetBlock(time, 32)
    #    time = ResnetBlock(time, 32)
    #    time = ResnetBlock(time, 32)
        time = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu', bias_initializer='ones')(time)
    #    filter = Lambda(lambda x: tf.cast(signal, tf.bool))(signal)
    #    filter = Lambda(lambda x: tf.cast(signal, tf.float32))(filter)
        time = Lambda(multi, output_shape=output_of_lambda)([time, signal])  # zero the time map wherever the signal map is zero
        # tower to produce time footprint
    #    time = Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(x)
    #    time = BatchNormalization()(time)
    #    time = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(time)
    #    time = BatchNormalization()(time)
    #    time = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu', bias_initializer='ones')(time)
    
    #    # tower to produce signal footprint
    #    signal = Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(x)
    #    signal = BatchNormalization()(signal)
    #    signal = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(signal)
    #    signal = BatchNormalization()(signal)
    #    signal = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(signal)
    
        # concatenate both towers
        outputs = Concatenate(axis=-1)([time,signal])
        return Model(inputs=inputs, outputs=outputs, name='generator')
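
    # Shape flow in the generator: the latent vector is reshaped to 1x1xlatent_size and
    # upsampled twice by a factor of 3 to 3x3 and then 9x9; each tower ends in a
    # single-channel Conv2D, and the time channel is masked where the signal is zero.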
    
    
    def build_critic():
        """Critic: a stack of densely connected convolution blocks (utils.DenselyConnectedConv)
        with a doubling number of filters and a single unconstrained output score."""
        inputs = Input(shape=(9, 9, 2))
        nfilter = 16
        x = Conv2D(nfilter, (3, 3), padding='same', activation='relu', kernel_initializer='he_normal')(inputs)
        for i in range(4):
            x = utils.DenselyConnectedConv(x, nfilter)
            nfilter = 2 * nfilter
            x = Conv2D(nfilter, (3, 3), padding='same', activation='relu', kernel_initializer='he_normal')(x)
        x = Flatten()(x)
        outputs = Dense(1)(x)
        return Model(inputs=inputs, outputs=outputs, name='critic')
    
    #    critic = Sequential(name='critic')
    #    critic.add(Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', input_shape=(9,9,2)))
    #    critic.add(LeakyReLU())
    #    critic.add(Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal'))
    #    critic.add(LeakyReLU())
    #    critic.add(Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal'))
    #    critic.add(LeakyReLU())
    #    critic.add(Conv2D(256, (3, 3), padding='same', kernel_initializer='he_normal'))
    #    critic.add(LeakyReLU())
    #    critic.add(GlobalMaxPooling2D())
    #    critic.add(Dense(100))
    #    critic.add(LeakyReLU())
    #    critic.add(Dense(1))
    #    return critic
    
    
    generator = build_generator(latent_size)
    plot_model(generator, to_file=log_dir + '/generator.png', show_shapes=True)
    critic = build_critic()
    plot_model(critic, to_file=log_dir + '/critic.png', show_shapes=True)
    
    # build the training model for the generator: the critic is frozen so that only
    # the generator weights are updated when this model is trained
    utils.make_trainable(critic, False)
    utils.make_trainable(generator, True)
    
    generator_in = Input(shape=(latent_size,))
    generator_out = generator(generator_in)
    critic_out = critic(generator_out)
    generator_training = Model(inputs=generator_in, outputs=critic_out)
    generator_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss])
    plot_model(generator_training, to_file=log_dir + '/generator_training.png', show_shapes=True)
    
    # build the training model for the critic: the generator is frozen and the critic
    # scores generated showers, real showers, and the randomly interpolated batch
    utils.make_trainable(critic, True)
    utils.make_trainable(generator, False)
    
    generator_in_critic_training = Input(shape=(latent_size,), name="noise")
    shower_in_critic_training = Input(shape=(9,9,2), name='shower_maps')
    generator_out_critic_training = generator(generator_in_critic_training)
    out_critic_training_gen = critic(generator_out_critic_training)
    out_critic_training_shower = critic(shower_in_critic_training)
    averaged_batch = RandomWeightedAverage(name='Average')([generator_out_critic_training, shower_in_critic_training])
    averaged_batch_out = critic(averaged_batch)
    critic_training = Model(inputs=[generator_in_critic_training, shower_in_critic_training], outputs=[out_critic_training_gen, out_critic_training_shower, averaged_batch_out])
    # the gradient penalty needs the interpolated batch as an extra argument, so it is
    # bound with functools.partial; Keras expects the resulting loss to expose a __name__
    gradient_penalty = partial(utils.gradient_penalty_loss, averaged_batch=averaged_batch, penalty_weight=GRADIENT_PENALTY_WEIGHT)
    gradient_penalty.__name__ = 'gradient_penalty'
    critic_training.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[utils.wasserstein_loss, utils.wasserstein_loss, gradient_penalty])
    plot_model(critic_training, to_file=log_dir + '/critic_training.png', show_shapes=True)
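
    # The three outputs of critic_training are matched to the losses above in order:
    # Wasserstein loss on generated showers, Wasserstein loss on real showers (with
    # opposite-sign labels, see positive_y / negative_y below), and the gradient penalty
    # evaluated on the interpolated batch. This assumes utils.wasserstein_loss is the
    # usual mean(y_true * y_pred) formulation.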
    
    # labels for the Wasserstein loss
    positive_y = np.ones(BATCH_SIZE)
    negative_y = -positive_y
    dummy = np.zeros(BATCH_SIZE)  # Keras requires a label for every output, so the gradient penalty gets a dummy target
    
    generator_loss = []
    critic_loss = []
    
    # training loop: the critic is updated NCR times for every generator update
    iterations_per_epoch = N // ((NCR + 1) * BATCH_SIZE)
    for epoch in range(EPOCHS):
        print("epoch: ", epoch)
        generated_map = generator.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
        utils.plot_showermaps(generated_map[0,:,:,:], title='Generated Events Epoch:', epoch=str(epoch), log_dir=log_dir)
        for iteration in range(iterations_per_epoch):
            for j in range(NCR):
                noise_batch = np.random.randn(BATCH_SIZE, latent_size)  # noise batch for the generator
                shower_batch = shower_maps[BATCH_SIZE * (j + iteration):BATCH_SIZE * (j + iteration + 1)]
                critic_loss.append(critic_training.train_on_batch([noise_batch, shower_batch], [negative_y, positive_y, dummy]))
                print("critic loss:", critic_loss[-1])
            noise_batch = np.random.randn(BATCH_SIZE, latent_size)  # noise batch for the generator
            generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y]))
            print("generator loss:", generator_loss[-1])
    
    utils.plot_loss(critic_loss, name="critic", iterations_per_epoch=iterations_per_epoch, NCR=NCR, log_dir=log_dir)
    utils.plot_loss(generator_loss, name="generator", iterations_per_epoch=iterations_per_epoch, log_dir=log_dir)
    
    
    # plot correlation and example maps for a batch of generated showers
    generated_map = generator.predict(np.random.randn(2000, latent_size))
    utils.plot_correlation(generated_map, log_dir=log_dir, name='generated')
    utils.plot_showermaps(generated_map[0,:,:,:], title='Generated Events Epoch:', epoch='Final', log_dir=log_dir)