From 882f60d2ed9057e05b4b73de12ccd32bfc02e6d2 Mon Sep 17 00:00:00 2001
From: Peter Fackeldey <peter.fackeldey@rwth-aachen.de>
Date: Fri, 19 Feb 2021 16:26:47 +0100
Subject: [PATCH] add permutation invariant dense prototype layer

---
 keras.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 64 insertions(+), 1 deletion(-)

diff --git a/keras.py b/keras.py
index 614688c..3625349 100644
--- a/keras.py
+++ b/keras.py
@@ -450,7 +450,6 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
     def on_train_end(self, logs=None):
         self.make_eval_plots()
 
-
     def make_input_plots(self):
         inps = self.x
         if not isinstance(inps, (list, tuple)):
@@ -1333,6 +1332,70 @@ class Xception1D(tf.keras.layers.Layer):
         return {"sub_kwargs": self.sub_kwargs}
 
 
+class PermutationInvariantDense1D(tf.keras.layers.Layer):
+    """
+    The PermutationInvariantDense1D layer applies a stack of
+    TimeDistributed(Dense) layers. Optionally, it appends a permutation
+    invariant pooling operation over the object (time) axis.
+
+    Parameters
+    ----------
+    name :
+        name of the layer
+    depth :
+        number of stacked TimeDistributed(Dense) layers
+    nfeatures:
+        number of output features of each Dense layer
+    batchnorm:
+        enable BatchNormalization after each additional Dense layer
+    pooling_op:
+        permutation invariant pooling operation (e.g. tf.reduce_sum),
+        applied over axis 1
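+
+    Example
+    -------
+    A minimal usage sketch (assuming ``tf.reduce_sum`` as the pooling
+    operation; the shapes are illustrative)::
+
+        layer = PermutationInvariantDense1D(
+            depth=2, nfeatures=16, pooling_op=tf.reduce_sum
+        )
+        out = layer(tf.random.normal((8, 10, 4)))
+        # out.shape == (8, 16): the object axis is pooled away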
+    """
+
+    def __init__(self, **kwargs):
+        name = kwargs.pop("name", "PermutationInvariantDense1D")
+        super().__init__(name=name)
+        self.depth = kwargs.pop("depth", 1)
+        self.nfeatures = kwargs.pop("nfeatures", 32)
+        self.batchnorm = kwargs.pop("batchnorm", False)
+        self.pooling_op = kwargs.pop("pooling_op", None)
+
+    def build(self, input_shape):
+        opts = dict(activation="elu", kernel_initializer="he_normal")
+        # first TimeDistributed(Dense) layer; Keras infers the input shape,
+        # so it is not passed to the inner Dense layer
+        network_layers = [
+            tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts))
+        ]
+        for _ in range(self.depth - 1):
+            network_layers.append(
+                tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts))
+            )
+            if self.batchnorm:
+                network_layers.append(tf.keras.layers.BatchNormalization())
+        if self.pooling_op:
+            # permutation invariant pooling op over the object (time) axis
+            network_layers.append(tf.keras.layers.Lambda(lambda x: self.pooling_op(x, axis=1)))
+        self.network_layers = network_layers
+
+    def call(self, input_tensor, training=False):
+        x = input_tensor
+        for layer in self.network_layers:
+            x = layer(x, training=training)
+        return x
+
+
 class SplitHighLow(tf.keras.layers.Layer):
     def call(self, inputs):
         return inputs[:, :, :4], inputs[:, :, 4:]
-- 
GitLab