From 3b6be777b2eb3ff8f3ba9ad0e5853ca362338db5 Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Fri, 31 Jul 2020 14:44:32 +0200 Subject: [PATCH] [keras] custom layers: now uses build function everywhere --- keras.py | 51 +++++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/keras.py b/keras.py index 16c94fe..8b9fcda 100644 --- a/keras.py +++ b/keras.py @@ -579,7 +579,7 @@ class CheckpointModel(tf.keras.callbacks.Callback): return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}" def on_epoch_end(self, epoch, logs=None): - if epoch != 0 and epoch % self.frequency == 0: + if epoch % self.frequency == 0: self.model.save(self.checkpoint_dir(epoch)) @@ -906,21 +906,22 @@ class DenseLayer(tf.keras.layers.Layer): self.l2 = l2 self.batch_norm = batch_norm + def build(self, input_shape): parts = [] - l2 = tf.keras.regularizers.l2(l2 if l2 else 0.0) - weights = tf.keras.layers.Dense(nodes, kernel_regularizer=l2) + l2 = tf.keras.regularizers.l2(self.l2) + weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2) parts.append(weights) - if batch_norm: + if self.batch_norm: dropout = 0.0 bn = tf.keras.layers.BatchNormalization() parts.append(bn) - act = tf.keras.layers.Activation(activation) + act = tf.keras.layers.Activation(self.activation) parts.append(act) - if activation == "selu": + if self.activation == "selu": dropout = tf.keras.layers.AlphaDropout(dropout) else: dropout = tf.keras.layers.Dropout(dropout) @@ -987,29 +988,32 @@ class FullyConnected(tf.keras.layers.Layer): Parameters ---------- - number_layers : int + layers : int The number of layers. kwargs : Arguments for DenseLayer. """ - def __init__(self, number_layers=0, **kwargs): + def __init__(self, layers=0, **kwargs): super().__init__(name="FullyConnected") - self.number_layers = number_layers - layers = [] - for layer in range(self.number_layers): - layers.append(DenseLayer(**kwargs)) self.layers = layers + self.kwargs = kwargs + + def build(self, input_shape): + _layers = [] + for layer in range(self.layers): + _layers.append(DenseLayer(**self.kwargs)) + self._layers = _layers def call(self, input_tensor, training=False): x = input_tensor - for layer in self.layers: + for layer in self._layers: x = layer(x, training=training) return x def get_config(self): - return {"number_layers": self.number_layers} + return {"layers": self.layers} class ResNet(tf.keras.layers.Layer): @@ -1018,29 +1022,32 @@ class ResNet(tf.keras.layers.Layer): Parameters ---------- - number_layers : int + layers : int The number of residual blocks. kwargs : Arguments for ResNetBlock. """ - def __init__(self, number_layers=1, **kwargs): + def __init__(self, layers=1, **kwargs): super().__init__(name="ResNet") - self.number_layers = number_layers - layers = [] - for i in range(self.number_layers): - layers.append(ResNetBlock(**kwargs)) self.layers = layers + self.kwargs = kwargs + + def build(self, input_shape): + _layers = [] + for i in range(self.layers): + _layers.append(ResNetBlock(**self.kwargs)) + self._layers = _layers def call(self, input_tensor, training=False): x = input_tensor - for layer in self.layers: + for layer in self._layers: x = layer(x, training=training) return x def get_config(self): - return {"number_layers": self.number_layers} + return {"layers": self.layers} class RemoveLayer(tf.keras.layers.Layer): -- GitLab