Commit 8b179f4f authored by Marcel Rieger's avatar Marcel Rieger

Add keras LBNLayer.

parent 17bbb86d
......@@ -13,7 +13,7 @@ __contact__ = "https://git.rwth-aachen.de/3pia/lbn"
__email__ = "marcel.rieger@cern.ch"
__version__ = "1.0.3"
__all__ = ["LBN", "FeatureFactoryBase", "FeatureFactory"]
__all__ = ["LBN", "LBNLayer", "FeatureFactoryBase", "FeatureFactory"]
import os
......@@ -522,6 +522,29 @@ class LBN(object):
self._norm_features = self.batch_norm(self.features, training=self.is_training)
class LBNLayer(tf.keras.layers.Layer):
"""
Keras layer of the :py:class:`LBN` that forwards the standard interface of :py:meth:`__init__`
and py:meth:`__call__`.
.. py:attribute:: lbn
type: LBN
Reference to the internal :py:class:`LBN` instance that is initialized with the contructor
arguments of this class.
"""
def __init__(self, *args, **kwargs):
super(LBNLayer, self).__init__()
# create the LBN instalce
self.lbn = LBN(*args, **kwargs)
def __call__(self, *args, **kwargs):
# forward to lbn.__call__
return self.lbn(*args, **kwargs)
class FeatureFactoryBase(object):
"""
Base class of the feature factory. It does not implement actual features but rather the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment