Commit f56cb424 authored by Dennis Noll

[keras] layers: added get_config for saving and loading

parent cee755b7
@@ -846,6 +846,12 @@ class DenseLayer(tf.keras.layers.Layer):
     def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False):
         super().__init__()
+        self.nodes = nodes
+        self.activation = activation
+        self.dropout = dropout
+        self.l2 = l2
+        self.batch_norm = batch_norm
+
         parts = []
         l2 = tf.keras.regularizers.l2(l2 if l2 else 0.0)
@@ -874,6 +880,15 @@ class DenseLayer(tf.keras.layers.Layer):
             x = part(x, training=training)
         return x
 
+    def get_config(self):
+        return {
+            "nodes": self.nodes,
+            "activation": self.activation,
+            "dropout": self.dropout,
+            "l2": self.l2,
+            "batch_norm": self.batch_norm,
+        }
+
 
 class ResNetBlock(tf.keras.layers.Layer):
     """
@@ -890,9 +905,9 @@ class ResNetBlock(tf.keras.layers.Layer):
     def __init__(self, config, jump=2, **kwargs):
         super().__init__(name="ResNetBlock")
         self.jump = jump
 
         layers = []
-        for i in range(jump - 1):
+        for i in range(self.jump - 1):
             layers.append(DenseLayer(**kwargs))
         activation = kwargs.pop("activation")
@@ -908,6 +923,9 @@ class ResNetBlock(tf.keras.layers.Layer):
         x = self.out_activation(x)
         return x
 
+    def get_config(self):
+        return {"jump": self.jump}
+
 
 class FullyConnected(tf.keras.layers.Layer):
     """
@@ -924,9 +942,9 @@ class FullyConnected(tf.keras.layers.Layer):
     def __init__(self, number_layers=0, **kwargs):
         super().__init__(name="FullyConnected")
         self.number_layers = number_layers
 
         layers = []
-        for layer in range(number_layers):
+        for layer in range(self.number_layers):
             layers.append(DenseLayer(**kwargs))
         self.layers = layers
@@ -936,6 +954,9 @@ class FullyConnected(tf.keras.layers.Layer):
             x = layer(x, training=training)
         return x
 
+    def get_config(self):
+        return {"number_layers": self.number_layers}
+
 
 class ResNet(tf.keras.layers.Layer):
     """
@@ -952,9 +973,9 @@ class ResNet(tf.keras.layers.Layer):
     def __init__(self, number_layers=1, **kwargs):
         super().__init__(name="ResNet")
         self.number_layers = number_layers
 
         layers = []
-        for i in range(number_layers):
+        for i in range(self.number_layers):
             layers.append(ResNetBlock(**kwargs))
         self.layers = layers
@@ -963,3 +984,8 @@ class ResNet(tf.keras.layers.Layer):
         for layer in self.layers:
             x = layer(x, training=training)
         return x
+
+    def get_config(self):
+        return {"number_layers": self.number_layers}