Commit efb7e16f authored by Dennis Noll

[keras] Networks: now all LinearNetworks come from one class

parent 30b622ea
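The diff below collapses three near-identical layer stacks into a single LinearNetwork base class; FullyConnected, ResNet, and DenseNet now only declare their name and the substructure (block class) that gets repeated. As a minimal sketch of the pattern (the subclass name and MyBlock are hypothetical, for illustration only), adding a further linear network afterwards reduces to:

class MyLinearVariant(LinearNetwork):
    name = "MyLinearVariant"
    substructure = MyBlock  # hypothetical block class; any layer accepting the forwarded kwargs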
@@ -1117,28 +1117,24 @@ class DenseNetBlock(tf.keras.layers.Layer):
         return {"block_size": self.block_size, "sub_kwargs": self.sub_kwargs}
 
 
-class FullyConnected(tf.keras.layers.Layer):
-    """
-    The FullyConnected object is an implementation of a fully connected DNN.
-
-    Parameters
-    ----------
-    layers : int
-        The number of layers.
-    kwargs :
-        Arguments for DenseLayer.
-    """
+class LinearNetwork(tf.keras.layers.Layer):
+    @property
+    def name(self):
+        raise NotImplementedError
+
+    @property
+    def substructure(self):
+        raise NotImplementedError
 
     def __init__(self, layers=0, sub_kwargs=None, **kwargs):
-        super().__init__(name="FullyConnected")
+        super().__init__(name=self.name)
         self.layers = layers
         self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
 
     def build(self, input_shape):
         network_layers = []
         for layer in range(self.layers):
-            network_layers.append(DenseLayer(**self.sub_kwargs))
+            network_layers.append(self.substructure(**self.sub_kwargs))
         self.network_layers = network_layers
 
     def call(self, input_tensor, training=False):
@@ -1151,41 +1147,41 @@ class FullyConnected(tf.keras.layers.Layer):
         return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
 
 
-class ResNet(tf.keras.layers.Layer):
+class FullyConnected(LinearNetwork):
     """
-    The ResNet object is an implementation of a Residual Neural Network.
+    The FullyConnected object is an implementation of a fully connected DNN.
 
     Parameters
    ----------
     layers : int
-        The number of residual blocks.
+        The number of layers.
     kwargs :
-        Arguments for ResNetBlock.
+        Arguments for DenseLayer.
     """
 
-    def __init__(self, layers=1, sub_kwargs=None, **kwargs):
-        super().__init__(name="ResNet")
-        self.layers = layers
-        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
+    name = "FullyConnected"
+    substructure = DenseLayer
 
-    def build(self, input_shape):
-        _layers = []
-        for i in range(self.layers):
-            _layers.append(ResNetBlock(**self.sub_kwargs))
-        self._layers = _layers
-
-    def call(self, input_tensor, training=False):
-        x = input_tensor
-        for layer in self._layers:
-            x = layer(x, training=training)
-        return x
+
+class ResNet(LinearNetwork):
+    """
+    The ResNet object is an implementation of a Residual Neural Network.
 
-    def get_config(self):
-        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
+    Parameters
+    ----------
+    layers : int
+        The number of residual blocks.
+    kwargs :
+        Arguments for ResNetBlock.
+    """
+
+    name = "ResNet"
+    substructure = ResNetBlock
 
 
-class DenseNet(tf.keras.layers.Layer):
+class DenseNet(LinearNetwork):
     """
     The DenseNet object is an implementation of a DenseNet Neural Network.
@@ -1198,25 +1194,8 @@ class DenseNet(tf.keras.layers.Layer):
     """
 
-    def __init__(self, layers=1, sub_kwargs=None, **kwargs):
-        super().__init__(name="DenseNet")
-        self.layers = layers
-        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
-
-    def build(self, input_shape):
-        _layers = []
-        for i in range(self.layers):
-            _layers.append(DenseNetBlock(**self.sub_kwargs))
-        self._layers = _layers
-
-    def call(self, input_tensor, training=False):
-        x = input_tensor
-        for layer in self._layers:
-            x = layer(x, training=training)
-        return x
-
-    def get_config(self):
-        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
+    name = "DenseNet"
+    substructure = DenseNetBlock
 
 
 class Xception1D(tf.keras.layers.Layer):
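For orientation, a minimal usage sketch of the refactored classes (a sketch, not part of the commit; the keyword arguments forwarded to DenseLayer, such as activation, are assumptions since DenseLayer's signature is not shown in this diff):

import tensorflow as tf

# Assumes the classes above are importable from this module.
net = FullyConnected(layers=3, activation="relu")  # extra kwargs are stored as sub_kwargs and forwarded to DenseLayer (assumed keyword)
x = tf.random.normal((32, 16))                     # batch of 32 samples, 16 features each
y = net(x, training=True)                          # build() creates 3 DenseLayer blocks, call() applies them in sequence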