diff --git a/keras.py b/keras.py
index 8b9fcda43df4dc7a60910003fe029fb1e115dd44..b8cebe697938248fbf9c46b4bf1535b9273a4adb 100644
--- a/keras.py
+++ b/keras.py
@@ -899,7 +899,7 @@ class DenseLayer(tf.keras.layers.Layer):
     """
 
     def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False):
-        super().__init__()
+        super().__init__(name="DenseLayer")
         self.nodes = nodes
         self.activation = activation
         self.dropout = dropout
@@ -914,7 +914,7 @@ class DenseLayer(tf.keras.layers.Layer):
             parts.append(weights)
 
         if self.batch_norm:
-            dropout = 0.0
+            self.dropout = 0.0
             bn = tf.keras.layers.BatchNormalization()
             parts.append(bn)
@@ -922,9 +922,9 @@
             parts.append(act)
 
         if self.activation == "selu":
-            dropout = tf.keras.layers.AlphaDropout(dropout)
+            dropout = tf.keras.layers.AlphaDropout(self.dropout)
         else:
-            dropout = tf.keras.layers.Dropout(dropout)
+            dropout = tf.keras.layers.Dropout(self.dropout)
         parts.append(dropout)
 
         self.parts = parts
@@ -995,25 +995,25 @@ class FullyConnected(tf.keras.layers.Layer):
     """
 
-    def __init__(self, layers=0, **kwargs):
+    def __init__(self, layers=0, sub_kwargs=None, **kwargs):
         super().__init__(name="FullyConnected")
         self.layers = layers
-        self.kwargs = kwargs
+        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
 
     def build(self, input_shape):
-        _layers = []
+        network_layers = []
         for layer in range(self.layers):
-            _layers.append(DenseLayer(**self.kwargs))
-        self._layers = _layers
+            network_layers.append(DenseLayer(**self.sub_kwargs))
+        self.network_layers = network_layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self._layers:
+        for layer in self.network_layers:
             x = layer(x, training=training)
         return x
 
     def get_config(self):
-        return {"layers": self.layers}
+        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
 
 
 class ResNet(tf.keras.layers.Layer):