Commit 8b179f4f authored by Marcel Rieger's avatar Marcel Rieger

Add keras LBNLayer.

parent 17bbb86d
...@@ -13,7 +13,7 @@ __contact__ = "https://git.rwth-aachen.de/3pia/lbn" ...@@ -13,7 +13,7 @@ __contact__ = "https://git.rwth-aachen.de/3pia/lbn"
__email__ = "marcel.rieger@cern.ch" __email__ = "marcel.rieger@cern.ch"
__version__ = "1.0.3" __version__ = "1.0.3"
__all__ = ["LBN", "FeatureFactoryBase", "FeatureFactory"] __all__ = ["LBN", "LBNLayer", "FeatureFactoryBase", "FeatureFactory"]
import os import os
...@@ -522,6 +522,29 @@ class LBN(object): ...@@ -522,6 +522,29 @@ class LBN(object):
self._norm_features = self.batch_norm(self.features, training=self.is_training) self._norm_features = self.batch_norm(self.features, training=self.is_training)
class LBNLayer(tf.keras.layers.Layer):
"""
Keras layer of the :py:class:`LBN` that forwards the standard interface of :py:meth:`__init__`
and py:meth:`__call__`.
.. py:attribute:: lbn
type: LBN
Reference to the internal :py:class:`LBN` instance that is initialized with the contructor
arguments of this class.
"""
def __init__(self, *args, **kwargs):
super(LBNLayer, self).__init__()
# create the LBN instalce
self.lbn = LBN(*args, **kwargs)
def __call__(self, *args, **kwargs):
# forward to lbn.__call__
return self.lbn(*args, **kwargs)
class FeatureFactoryBase(object): class FeatureFactoryBase(object):
""" """
Base class of the feature factory. It does not implement actual features but rather the Base class of the feature factory. It does not implement actual features but rather the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment