Commit 9896361d authored by Marcel Rieger's avatar Marcel Rieger

Add keras layer and model tests.

parent 8b179f4f
Pipeline #118287 failed with stages
in 34 seconds
......@@ -13,7 +13,7 @@ import unittest
import numpy as np
import tensorflow as tf
from lbn import LBN, FeatureFactory
from lbn import LBN, LBNLayer, FeatureFactory
class TestCase(unittest.TestCase):
......@@ -21,8 +21,12 @@ class TestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestCase, self).__init__(*args, **kwargs)
# fixate random seeds
np.random.seed(123)
tf.random.set_seed(123)
# create some four-vectors with fixed seed and batch size 2
self.vectors = create_four_vectors((2, 10), seed=123)
self.vectors = create_four_vectors((2, 10))
self.vectors_t = tf.constant(self.vectors, dtype=tf.float32)
# common feature set
......@@ -108,7 +112,7 @@ class TestCase(unittest.TestCase):
"particles_px", "particles_py", "particles_pz", "particles_pvec", "particles",
"restframes_E", "restframes_px", "restframes_py", "restframes_pz", "restframes_pvec",
"restframes", "Lambda", "boosted_particles", "_raw_features", "_norm_features",
"features",
"features", "batch_norm",
]
lbn = LBN(10, boost_mode=LBN.PAIRS, is_training=True)
......@@ -312,6 +316,32 @@ class TestCase(unittest.TestCase):
# test the custom feature
self.assertAlmostEqual(lbn.feature_factory.px_plus_py().numpy()[1, 0], -36.780174, 4)
def test_keras_layer(self):
l = LBNLayer(10, boost_mode=LBN.PAIRS, batch_norm=True, is_training=True)
self.assertIsInstance(l.lbn, LBN)
self.assertTrue(l.lbn.batch_norm_center)
# build a custom model
class Model(tf.keras.models.Model):
def __init__(self):
super(Model, self).__init__()
self.lbn = l
self.dense = tf.keras.layers.Dense(1024, activation="elu")
self.softmax = tf.keras.layers.Dense(2, activation="softmax")
def __call__(self, *args, **kwargs):
return self.softmax(self.dense(self.lbn(*args, **kwargs)))
model = Model()
output = model(self.vectors_t, features=self.feature_set).numpy()
self.assertAlmostEqual(output[0, 0], 0.321891, 5)
self.assertAlmostEqual(output[0, 1], 0.678109, 5)
self.assertAlmostEqual(output[1, 0], 0.625410, 5)
self.assertAlmostEqual(output[1, 1], 0.374590, 5)
def create_four_vectors(n, p_low=-100., p_high=100., m_low=0.1, m_high=50., seed=None):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment