diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index db54f5f834aadb1dae8535147ee5d3659a0735ea..8f6778a974cbb649b681b3c99a80f24a6c56977a 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -12,7 +12,7 @@ lint:
     - pip install flake8 --user
    - flake8 lbn.py test.py setup.py
 
-unittests_py2:
+unittest_tf1_py2:
   stage: test
   tags:
     - docker
@@ -20,10 +20,26 @@ unittests_py2:
   script:
     - python -m unittest test
 
-unittests_py3:
+unittests_tf1_py3:
   stage: test
   tags:
     - docker
   image: tensorflow/tensorflow:1.13.1-py3
   script:
     - python -m unittest test
+
+unittest_tf2_py2:
+  stage: test
+  tags:
+    - docker
+  image: tensorflow/tensorflow:2.0.0a0
+  script:
+    - python -m unittest test
+
+unittests_tf2_py3:
+  stage: test
+  tags:
+    - docker
+  image: tensorflow/tensorflow:2.0.0a0-py3
+  script:
+    - python -m unittest test
diff --git a/README.md b/README.md
index b246c7d9a90c34c2c856cf3805a4dd061b0ef5f7..202e13ca70e3c677122506e1fed9d5209a3a0a1c 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ Original repository: [git.rwth-aachen.de/3pia/lbn](https://git.rwth-aachen.de/3p
 ### Usage example
 
 ```python
+import tensorflow as tf
 from lbn import LBN
 
 # initialize the LBN, set 10 combinations and pairwise boosting
@@ -20,6 +21,25 @@ features = lbn(four_vectors)
 ...
 ```
 
+Or with TensorFlow 2 and Keras:
+
+```python
+import tensorflow as tf
+from lbn import LBN, LBNLayer
+
+# start a sequential model
+model = tf.keras.models.Sequential()
+
+# add the LBN layer
+model.add(LBNLayer(10, boost_mode=LBN.PAIRS))
+
+# add a dense layer
+model.add(tf.keras.layers.Dense(1024))
+
+# continue building and training the model
+...
+```
+
 
 ### Installation and dependencies
 
@@ -29,23 +49,21 @@ Via [pip](https://pypi.python.org/pypi/lbn):
 pip install lbn
 ```
 
-NumPy and TensorFlow (1.X) are the only dependencies.
-
-TensorFlow 2.0 is not *yet* supported.
+NumPy and TensorFlow are the only dependencies. Both TensorFlow v1 and v2 are supported.
 
 
 ### Testing
 
-Tests should be run for Python 2 and 3. The following commands assume you are root directory of the LBN respository:
+Tests should be run for Python 2 and 3 and for TensorFlow 1 and 2. The following commands assume you are in the root directory of the LBN repository:
 
 ```bash
 python -m unittest test
 
-# or via docker, python 2
+# or via docker, python 2 and tf 1
 docker run --rm -v `pwd`:/root/lbn -w /root/lbn tensorflow/tensorflow:1.13.1 python -m unittest test
 
-# or via docker, python 3
-docker run --rm -v `pwd`:/root/lbn -w /root/lbn tensorflow/tensorflow:1.13.1-py3 python -m unittest test
+# or via docker, python 3 and tf 2
+docker run --rm -v `pwd`:/root/lbn -w /root/lbn tensorflow/tensorflow:2.0.0a0-py3 python -m unittest test
 ```
 
 
diff --git a/lbn.py b/lbn.py
index 0e31643f352163942588db8508ee2e5f354338f8..ea675dd25b2204b92293cb2c634458bce89294ee 100644
--- a/lbn.py
+++ b/lbn.py
@@ -13,7 +13,7 @@ __contact__ = "https://git.rwth-aachen.de/3pia/lbn"
 __email__ = "marcel.rieger@cern.ch"
 __version__ = "1.0.3"
 
-__all__ = ["LBN", "FeatureFactoryBase", "FeatureFactory"]
+__all__ = ["LBN", "LBNLayer", "FeatureFactoryBase", "FeatureFactory"]
 
 
 import os
@@ -118,6 +118,9 @@
         else:
             raise ValueError("invalid batch_norm, should be bool or list/tuple of two bools")
 
+        # the keras batch normalization layer
+        self.batch_norm = None
+
         # particle weights and settings
         self.particle_weights = particle_weights
         self.abs_particle_weights = abs_particle_weights
@@ -210,7 +213,7 @@
         if self.features is None:
             return None
 
-        return self.features.shape[-1].value
+        return int(self.features.shape[-1])
 
     def register_feature(self, func=None, **kwargs):
         """
@@ -300,8 +303,8 @@
         self.inputs = inputs
 
         # infer sizes
-        self.n_in = self.inputs.shape[1].value
-        self.n_dim = self.inputs.shape[2].value
+        self.n_in = int(self.inputs.shape[1])
+        self.n_dim = int(self.inputs.shape[2])
 
         if self.n_dim != 4:
             raise Exception("input dimension must be 4 to represent 4-vectors")
@@ -509,16 +512,37 @@
     def build_norm(self):
         """
         Applies simple batch normalization with floating averages to the output features using
-        ``tf.layers.batch_normalization``. Make sure to also run the operation returned by
-        ``tf.get_collection(tf.GraphKeys.UPDATE_OPS)`` during each train step.
+        ``tf.keras.layers.BatchNormalization``.
         """
-        self._norm_features = tf.layers.batch_normalization(
-            self.features,
+        self.batch_norm = tf.keras.layers.BatchNormalization(
             axis=1,
-            training=self.is_training,
             center=self.batch_norm_center,
             scale=self.batch_norm_scale,
         )
+        self._norm_features = self.batch_norm(self.features, training=self.is_training)
+
+
+class LBNLayer(tf.keras.layers.Layer):
+    """
+    Keras layer of the :py:class:`LBN` that forwards the standard interface of :py:meth:`__init__`
+    and :py:meth:`__call__`.
+
+    .. py:attribute:: lbn
+       type: LBN
+
+       Reference to the internal :py:class:`LBN` instance that is initialized with the constructor
+       arguments of this class.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(LBNLayer, self).__init__()
+
+        # create the LBN instance
+        self.lbn = LBN(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        # forward to lbn.__call__
+        return self.lbn(*args, **kwargs)
 
 
 class FeatureFactoryBase(object):
diff --git a/test.py b/test.py
index d40c75ba2047933ecf67aeb8fb82b02688fa2326..cc34e67d918e9b75e581fe5a797592c90febb345 100644
--- a/test.py
+++ b/test.py
@@ -8,15 +8,20 @@ LBN unit tests.
 __all__ = ["TestCase"]
 
 
+import sys
 import unittest
 
 import numpy as np
 import tensorflow as tf
 
-from lbn import LBN, FeatureFactory
+from lbn import LBN, LBNLayer, FeatureFactory
 
-# enable eager execution
-tf.enable_eager_execution()
+
+PY3 = sys.version.startswith("3.")
+TF2 = tf.__version__.startswith("2.")
+
+if not TF2:
+    tf.enable_eager_execution()
 
 
 class TestCase(unittest.TestCase):
@@ -24,8 +29,15 @@
     def __init__(self, *args, **kwargs):
         super(TestCase, self).__init__(*args, **kwargs)
 
+        # fixate random seeds
+        np.random.seed(123)
+        if TF2:
+            tf.random.set_seed(123)
+        else:
+            tf.random.set_random_seed(123)
+
         # create some four-vectors with fixed seed and batch size 2
-        self.vectors = create_four_vectors((2, 10), seed=123)
+        self.vectors = create_four_vectors((2, 10))
         self.vectors_t = tf.constant(self.vectors, dtype=tf.float32)
 
         # common feature set
@@ -111,7 +123,7 @@
             "particles_px", "particles_py", "particles_pz", "particles_pvec", "particles",
             "restframes_E", "restframes_px", "restframes_py", "restframes_pz", "restframes_pvec",
             "restframes", "Lambda", "boosted_particles", "_raw_features", "_norm_features",
-            "features",
+            "features", "batch_norm",
         ]
 
         lbn = LBN(10, boost_mode=LBN.PAIRS, is_training=True)
@@ -315,6 +327,32 @@
         # test the custom feature
         self.assertAlmostEqual(lbn.feature_factory.px_plus_py().numpy()[1, 0], -36.780174, 4)
 
+    def test_keras_layer(self):
+        l = LBNLayer(10, boost_mode=LBN.PAIRS, batch_norm=True, is_training=True)
+        self.assertIsInstance(l.lbn, LBN)
+        self.assertTrue(l.lbn.batch_norm_center)
+
+        # build a custom model
+        class Model(tf.keras.models.Model):
+
+            def __init__(self):
+                super(Model, self).__init__()
+
+                self.lbn = l
+                self.dense = tf.keras.layers.Dense(1024, activation="elu")
+                self.softmax = tf.keras.layers.Dense(2, activation="softmax")
+
+            def __call__(self, *args, **kwargs):
+                return self.softmax(self.dense(self.lbn(*args, **kwargs)))
+
+        model = Model()
+        output = model(self.vectors_t, features=self.feature_set).numpy()
+
+        self.assertAlmostEqual(output[0, 0], 0.548664 if PY3 else 0.795995, 5)
+        self.assertAlmostEqual(output[0, 1], 0.451337 if PY3 else 0.204005, 5)
+        self.assertAlmostEqual(output[1, 0], 0.394629 if PY3 else 0.177576, 5)
+        self.assertAlmostEqual(output[1, 1], 0.605371 if PY3 else 0.822424, 5)
+
 
 def create_four_vectors(n, p_low=-100., p_high=100., m_low=0.1, m_high=50., seed=None):
     """