Commit 3b6be777 authored by Dennis Noll

[keras] custom layers: now uses build function everywhere

parent 5de2bb08
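This commit moves sub-layer construction out of `__init__` and into each layer's `build` method, which Keras calls automatically the first time the layer is applied to an input. A minimal sketch of the pattern, with illustrative names that are not taken from this repository:

import tensorflow as tf

class ExampleLayer(tf.keras.layers.Layer):
    # hypothetical layer illustrating the __init__/build split
    def __init__(self, nodes=64, **kwargs):
        super().__init__(**kwargs)
        self.nodes = nodes  # store configuration only; create nothing yet

    def build(self, input_shape):
        # invoked once by Keras, when the input shape is first known
        self.dense = tf.keras.layers.Dense(self.nodes)

    def call(self, input_tensor, training=False):
        return self.dense(input_tensor)

    def get_config(self):
        return {"nodes": self.nodes}

Deferring construction this way keeps `__init__` cheap and lets sub-layers be created against the actual input shape.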
@@ -579,7 +579,7 @@ class CheckpointModel(tf.keras.callbacks.Callback):
         return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}"
 
     def on_epoch_end(self, epoch, logs=None):
-        if epoch != 0 and epoch % self.frequency == 0:
+        if epoch % self.frequency == 0:
             self.model.save(self.checkpoint_dir(epoch))
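With the `epoch != 0` guard removed, the callback now also saves at epoch 0; Keras passes 0-indexed epoch numbers to `on_epoch_end`, so the first checkpoint no longer waits a full `frequency` of epochs. A hedged usage sketch; the constructor arguments are assumed from the attributes referenced above (`savedir`, `identifier`, `frequency`) and are not shown in this diff:

import numpy as np
import tensorflow as tf

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# assumed signature, inferred from the attributes used in on_epoch_end
checkpoint = CheckpointModel(savedir="/tmp/ckpts", identifier="run1", frequency=5)
model.fit(x, y, epochs=20, callbacks=[checkpoint])
# saves after epochs 0, 5, 10 and 15; previously epoch 0 was skipped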
@@ -906,21 +906,22 @@ class DenseLayer(tf.keras.layers.Layer):
         self.l2 = l2
         self.batch_norm = batch_norm
 
+    def build(self, input_shape):
         parts = []
-        l2 = tf.keras.regularizers.l2(l2 if l2 else 0.0)
-        weights = tf.keras.layers.Dense(nodes, kernel_regularizer=l2)
+        l2 = tf.keras.regularizers.l2(self.l2)
+        weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2)
         parts.append(weights)
-        if batch_norm:
+        if self.batch_norm:
             dropout = 0.0
             bn = tf.keras.layers.BatchNormalization()
             parts.append(bn)
-        act = tf.keras.layers.Activation(activation)
+        act = tf.keras.layers.Activation(self.activation)
         parts.append(act)
-        if activation == "selu":
+        if self.activation == "selu":
             dropout = tf.keras.layers.AlphaDropout(dropout)
         else:
             dropout = tf.keras.layers.Dropout(dropout)
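For orientation, a runnable stand-in for the rebuilt block. The `call` method and the `nodes`/`dropout` attributes sit outside this hunk, so their form below is an assumption:

import tensorflow as tf

class DenseLayerSketch(tf.keras.layers.Layer):
    # hypothetical stand-in, not the committed DenseLayer
    def __init__(self, nodes=64, activation="relu", dropout=0.0, l2=0.0, batch_norm=False):
        super().__init__()
        self.nodes = nodes
        self.activation = activation
        self.dropout = dropout  # assumed attribute; the hunk only shows its batch_norm reset
        self.l2 = l2
        self.batch_norm = batch_norm

    def build(self, input_shape):
        parts = []
        reg = tf.keras.regularizers.l2(self.l2)
        parts.append(tf.keras.layers.Dense(self.nodes, kernel_regularizer=reg))
        dropout = self.dropout
        if self.batch_norm:
            dropout = 0.0  # batch norm disables dropout, as in the hunk
            parts.append(tf.keras.layers.BatchNormalization())
        parts.append(tf.keras.layers.Activation(self.activation))
        # selu pairs with AlphaDropout to preserve self-normalization
        if self.activation == "selu":
            parts.append(tf.keras.layers.AlphaDropout(dropout))
        else:
            parts.append(tf.keras.layers.Dropout(dropout))
        self.parts = parts

    def call(self, input_tensor, training=False):
        x = input_tensor
        for part in self.parts:
            x = part(x, training=training)
        return x

Note that the old `l2 if l2 else 0.0` guard is gone, so `self.l2` is now expected to be numeric (e.g. `0.0`) rather than `None`.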
@@ -987,29 +988,32 @@ class FullyConnected(tf.keras.layers.Layer):
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of layers.
     kwargs :
         Arguments for DenseLayer.
     """
 
-    def __init__(self, number_layers=0, **kwargs):
+    def __init__(self, layers=0, **kwargs):
         super().__init__(name="FullyConnected")
-        self.number_layers = number_layers
-        layers = []
-        for layer in range(self.number_layers):
-            layers.append(DenseLayer(**kwargs))
         self.layers = layers
+        self.kwargs = kwargs
+
+    def build(self, input_shape):
+        _layers = []
+        for layer in range(self.layers):
+            _layers.append(DenseLayer(**self.kwargs))
+        self._layers = _layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self._layers:
             x = layer(x, training=training)
         return x
 
     def get_config(self):
-        return {"number_layers": self.number_layers}
+        return {"layers": self.layers}
 
 
 class ResNet(tf.keras.layers.Layer):
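Since `__init__` now stores only `layers` and `kwargs`, the DenseLayer stack does not exist until Keras first calls the layer and triggers `build`. A round-trip sketch; the `nodes` and `activation` keyword names come from the DenseLayer attributes above, their values here are invented:

import tensorflow as tf

fc = FullyConnected(layers=3, nodes=128, activation="relu")
out = fc(tf.zeros((2, 16)))   # first call triggers build(), creating the DenseLayers
print(fc.get_config())        # {'layers': 3}

Note that `get_config` returns only `layers`; the stored `kwargs` are not serialized, so restoring from a config would rebuild the stack with DenseLayer defaults.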
@@ -1018,29 +1022,32 @@ class ResNet(tf.keras.layers.Layer):
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of residual blocks.
     kwargs :
         Arguments for ResNetBlock.
     """
 
-    def __init__(self, number_layers=1, **kwargs):
+    def __init__(self, layers=1, **kwargs):
         super().__init__(name="ResNet")
-        self.number_layers = number_layers
-        layers = []
-        for i in range(self.number_layers):
-            layers.append(ResNetBlock(**kwargs))
         self.layers = layers
+        self.kwargs = kwargs
+
+    def build(self, input_shape):
+        _layers = []
+        for i in range(self.layers):
+            _layers.append(ResNetBlock(**self.kwargs))
+        self._layers = _layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self._layers:
             x = layer(x, training=training)
         return x
 
     def get_config(self):
-        return {"number_layers": self.number_layers}
+        return {"layers": self.layers}
 
 
 class RemoveLayer(tf.keras.layers.Layer):
......
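`ResNetBlock` itself is not part of this diff. For readers unfamiliar with the term, a generic residual block in the same build-based style could look like the following; this is entirely hypothetical, only the class name comes from the code above, and the skip connection assumes the block preserves the feature dimension:

import tensorflow as tf

class ResNetBlockSketch(tf.keras.layers.Layer):
    # hypothetical illustration, not the committed ResNetBlock
    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def build(self, input_shape):
        # two stacked sub-layers whose output is added back onto the input
        self.first = DenseLayer(**self.kwargs)
        self.second = DenseLayer(**self.kwargs)

    def call(self, input_tensor, training=False):
        x = self.first(input_tensor, training=training)
        x = self.second(x, training=training)
        return x + input_tensor  # residual (skip) connection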