diff --git a/keras.py b/keras.py
index 9b78cdd383b46f6fe9483f472e8bedeaf97e5b52..9ae24da70571504dfbccd3ae65d7abd224d02830 100644
--- a/keras.py
+++ b/keras.py
@@ -975,13 +975,20 @@ class DenseLayer(tf.keras.layers.Layer):

     def build(self, input_shape):
         parts = []
+        if self.activation == "selu":
+            kernel_initializer = tf.keras.initializers.VarianceScaling(scale=1.0, mode="fan_in")
+        else:
+            kernel_initializer = "glorot_uniform"
         l2 = tf.keras.regularizers.l2(self.l2)
-        weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2)
+        weights = tf.keras.layers.Dense(
+            self.nodes,
+            kernel_regularizer=l2,
+            kernel_initializer=kernel_initializer,
+        )
         parts.append(weights)

-        if self.batch_norm:
-            self.dropout = 0.0
+        if self.batch_norm and not self.activation == "selu":
             bn = tf.keras.layers.BatchNormalization()
             parts.append(bn)
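For reference, a minimal standalone sketch of the initializer selection this patch introduces, assuming plain `tf.keras`; the `make_dense` helper, its arguments, and the default L2 factor are illustrative and not part of the patch. `VarianceScaling(scale=1.0, mode="fan_in")` is the LeCun-style initialization that SELU's self-normalization assumes, which is also why the patch stops adding `BatchNormalization` when the activation is SELU.

```python
import tensorflow as tf

def make_dense(nodes, activation, l2=1e-4):
    """Illustrative helper: choose a kernel initializer to match the activation."""
    if activation == "selu":
        # SELU expects LeCun-style (fan_in, scale=1.0) initialization.
        kernel_initializer = tf.keras.initializers.VarianceScaling(scale=1.0, mode="fan_in")
    else:
        kernel_initializer = "glorot_uniform"
    return tf.keras.layers.Dense(
        nodes,
        activation=activation,
        kernel_regularizer=tf.keras.regularizers.l2(l2),
        kernel_initializer=kernel_initializer,
    )
```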