diff --git a/keras.py b/keras.py
index 614688cd62339e35f36513928c72e9e5d575592f..3625349ac30082dc00b36d7b3c6af36a9eae0549 100644
--- a/keras.py
+++ b/keras.py
@@ -450,7 +450,6 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
     def on_train_end(self, logs=None):
         self.make_eval_plots()
 
-
     def make_input_plots(self):
         inps = self.x
         if not isinstance(inps, (list, tuple)):
@@ -1333,6 +1332,61 @@ class Xception1D(tf.keras.layers.Layer):
         return {"sub_kwargs": self.sub_kwargs}
 
 
+
+class PermutationInvariantDense1D(tf.keras.layers.Layer):
+    """
+    Applies a stack of TimeDistributed(Dense) layers to every object along
+    the time axis and optionally collapses that axis with a permutation
+    invariant pooling operation.
+
+    Parameters
+    ----------
+    name :
+        name of the layer
+    depth :
+        number of stacked TimeDistributed(Dense) layers
+    nfeatures :
+        number of units of each Dense layer
+    batchnorm :
+        add a BatchNormalization layer after each hidden Dense layer
+    pooling_op :
+        permutation invariant pooling operation applied over the object
+        axis, e.g. tf.reduce_sum or tf.reduce_max
+    """
+
+    def __init__(self, **kwargs):
+        name = kwargs.pop("name", "PermutationInvariantDense1D")
+        self.depth = kwargs.pop("depth", 1)
+        self.nfeatures = kwargs.pop("nfeatures", 32)
+        self.batchnorm = kwargs.pop("batchnorm", False)
+        self.pooling_op = kwargs.pop("pooling_op", None)
+        super().__init__(name=name, **kwargs)
+
+    def build(self, input_shape):
+        opts = dict(activation="elu", kernel_initializer="he_normal")
+        # the first TimeDistributed(Dense) layer is part of network_layers
+        # so that call() applies it; Keras infers its input shape at build
+        network_layers = [
+            tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts))
+        ]
+        for _ in range(self.depth - 1):
+            network_layers.append(
+                tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts))
+            )
+            if self.batchnorm:
+                network_layers.append(tf.keras.layers.BatchNormalization())
+        if self.pooling_op:
+            # permutation invariant pooling op collapsing the object axis
+            network_layers.append(tf.keras.layers.Lambda(lambda x: self.pooling_op(x, axis=1)))
+        self.network_layers = network_layers
+
+    def call(self, input_tensor, training=False):
+        x = input_tensor
+        for layer in self.network_layers:
+            x = layer(x, training=training)
+        return x
+
+
 class SplitHighLow(tf.keras.layers.Layer):
     def call(self, inputs):
         return inputs[:, :, :4], inputs[:, :, 4:]
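
For reference, a minimal usage sketch of the new layer (not part of the diff). It assumes keras.py is importable as a module named keras and picks tf.reduce_sum as an example pooling_op; adjust both to your setup.

    import numpy as np
    import tensorflow as tf

    from keras import PermutationInvariantDense1D  # the keras.py touched by this diff

    # two stacked TimeDistributed(Dense) layers with 16 units each,
    # followed by sum pooling over the object axis (axis=1)
    layer = PermutationInvariantDense1D(depth=2, nfeatures=16, pooling_op=tf.reduce_sum)

    # toy input: 8 events, 5 objects per event, 4 features per object
    x = np.random.normal(size=(8, 5, 4)).astype("float32")
    out = layer(x)
    print(out.shape)  # (8, 16): the object axis is pooled away

    # shuffling the objects within each event must not change the output
    perm = np.random.permutation(5)
    out_perm = layer(x[:, perm, :])
    np.testing.assert_allclose(out.numpy(), out_perm.numpy(), rtol=1e-5, atol=1e-6)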