[lbn] LBNLayer: now can load LBNLayer with tf.keras.models.load_model

parent d3243835
Pipeline #259342 failed with stages
......@@ -75,7 +75,7 @@ class LBN(object):
def __init__(self, n_particles, n_restframes=None, boost_mode=PAIRS, feature_factory=None,
particle_weights=None, abs_particle_weights=True, clip_particle_weights=False,
restframe_weights=None, abs_restframe_weights=True, clip_restframe_weights=False,
weight_init=None, epsilon=1e-5, name=None):
weight_init=None, epsilon=1e-5, name=None, **kwargs):
super(LBN, self).__init__()
# determine the number of output particles, which depends on the boost mode
......@@ -502,8 +502,8 @@ class LBNLayer(tf.keras.layers.Layer):
arguments of this class.
"""
def __init__(self, *args, **kwargs):
super(LBNLayer, self).__init__()
def __init__(self, n_particles, *args, **kwargs):
self.n_particles = n_particles
# store names of features to build
self.feature_names = kwargs.pop("features", None)
......@@ -512,7 +512,8 @@ class LBNLayer(tf.keras.layers.Layer):
self.seed = kwargs.pop("seed", None)
# create the LBN instance with the remaining arguments
self.lbn = LBN(*args, **kwargs)
self.lbn = LBN(n_particles, *args, **kwargs)
super(LBNLayer, self).__init__()
def build(self, input_shape):
# get the number of input vectors
......@@ -549,6 +550,15 @@ class LBNLayer(tf.keras.layers.Layer):
def compute_output_shape(self, input_shape):
return (input_shape[0], self.lbn.n_features)
def get_config(self):
config = {
"n_particles": self.n_particles,
"features": self.feature_names,
"seed": self.seed,
}
base_config = super(LBNLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class FeatureFactoryBase(object):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment