From 3b6be777b2eb3ff8f3ba9ad0e5853ca362338db5 Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Fri, 31 Jul 2020 14:44:32 +0200
Subject: [PATCH] [keras] custom layers: now use the build function everywhere

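Sublayers were previously constructed in __init__, before the input
shape is known. DenseLayer, FullyConnected, and ResNet now store their
constructor arguments and create their sublayers in build(), which
Keras invokes lazily on the first call. The number_layers argument is
renamed to layers, and CheckpointModel now also saves a checkpoint at
epoch 0.

A minimal usage sketch of the new behavior (the DenseLayer kwargs shown
here, such as nodes and activation, are assumptions based on the hunks
below):

    import tensorflow as tf

    net = FullyConnected(layers=3, nodes=128, activation="selu")
    # the DenseLayer sublayers are only instantiated here, inside
    # build(), once the input shape is known
    out = net(tf.zeros((2, 16)))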
---
 keras.py | 51 +++++++++++++++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/keras.py b/keras.py
index 16c94fe..8b9fcda 100644
--- a/keras.py
+++ b/keras.py
@@ -579,7 +579,7 @@ class CheckpointModel(tf.keras.callbacks.Callback):
         return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}"
 
     def on_epoch_end(self, epoch, logs=None):
-        if epoch != 0 and epoch % self.frequency == 0:
+        if epoch % self.frequency == 0:
             self.model.save(self.checkpoint_dir(epoch))
 
 
@@ -906,21 +906,22 @@ class DenseLayer(tf.keras.layers.Layer):
         self.l2 = l2
         self.batch_norm = batch_norm
 
+    def build(self, input_shape):
         parts = []
 
-        l2 = tf.keras.regularizers.l2(l2 if l2 else 0.0)
-        weights = tf.keras.layers.Dense(nodes, kernel_regularizer=l2)
+        # keep the None guard from the old __init__ code: passing
+        # l2=None straight into the regularizer would fail
+        l2 = tf.keras.regularizers.l2(self.l2 if self.l2 else 0.0)
+        weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2)
         parts.append(weights)
 
-        if batch_norm:
-            dropout = 0.0
+        # fall back to the rate stored in __init__ (assumed to be kept
+        # as self.dropout, like the other attributes); without this,
+        # `dropout` is undefined below whenever batch norm is disabled
+        dropout = 0.0 if self.batch_norm else self.dropout
+        if self.batch_norm:
             bn = tf.keras.layers.BatchNormalization()
             parts.append(bn)
 
-        act = tf.keras.layers.Activation(activation)
+        act = tf.keras.layers.Activation(self.activation)
         parts.append(act)
 
-        if activation == "selu":
+        if self.activation == "selu":
             dropout = tf.keras.layers.AlphaDropout(dropout)
         else:
             dropout = tf.keras.layers.Dropout(dropout)
@@ -987,29 +988,32 @@ class FullyConnected(tf.keras.layers.Layer):
 
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of layers.
     kwargs :
         Arguments for DenseLayer.
 
     """
 
-    def __init__(self, number_layers=0, **kwargs):
+    def __init__(self, layers=0, **kwargs):
         super().__init__(name="FullyConnected")
-        self.number_layers = number_layers
-        layers = []
-        for layer in range(self.number_layers):
-            layers.append(DenseLayer(**kwargs))
         self.layers = layers
+        self.kwargs = kwargs
+
+    def build(self, input_shape):
+        _layers = []
+        for _ in range(self.layers):
+            _layers.append(DenseLayer(**self.kwargs))
+        self._layers = _layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self._layers:
             x = layer(x, training=training)
         return x
 
     def get_config(self):
-        return {"number_layers": self.number_layers}
+        # also persist the stored DenseLayer kwargs so that
+        # from_config() restores an equivalent layer
+        return {"layers": self.layers, **self.kwargs}
 
 
 class ResNet(tf.keras.layers.Layer):
@@ -1018,29 +1022,32 @@ class ResNet(tf.keras.layers.Layer):
 
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of residual blocks.
     kwargs :
         Arguments for ResNetBlock.
 
     """
 
-    def __init__(self, number_layers=1, **kwargs):
+    def __init__(self, layers=1, **kwargs):
         super().__init__(name="ResNet")
-        self.number_layers = number_layers
-        layers = []
-        for i in range(self.number_layers):
-            layers.append(ResNetBlock(**kwargs))
         self.layers = layers
+        self.kwargs = kwargs
+
+    def build(self, input_shape):
+        _layers = []
+        for _ in range(self.layers):
+            _layers.append(ResNetBlock(**self.kwargs))
+        self._layers = _layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self._layers:
             x = layer(x, training=training)
         return x
 
     def get_config(self):
-        return {"number_layers": self.number_layers}
+        # also persist the stored ResNetBlock kwargs so that
+        # from_config() restores an equivalent layer
+        return {"layers": self.layers, **self.kwargs}
 
 
 class RemoveLayer(tf.keras.layers.Layer):
-- 
GitLab