diff --git a/keras.py b/keras.py index 77a0ab84cafada0d41f9ae1dc0ceae0b61cbf7b7..3c1c3dbea00038a076df678f1640e09c4c6be2bb 100644 --- a/keras.py +++ b/keras.py @@ -1016,6 +1016,16 @@ class ResNet(tf.keras.layers.Layer): return {"number_layers": self.number_layers} +class RemoveLayer(tf.keras.layers.Layer): + def call(self, inputs): + return tf.keras.layers.Concatenate()([inputs[:, :, :4], inputs[:, :, 6:]]) + + +class SplitHighLow(tf.keras.layers.Layer): + def call(self, inputs): + return inputs[:, :, :4], inputs[:, :, 4:] + + class LBNLayer(tf.keras.layers.Layer): """ Custom implementation of the LBNLayer with automatic cropping to @@ -1064,14 +1074,8 @@ class LBNLayer(tf.keras.layers.Layer): # "pair_dy", ] - def split_low_high(self, tensor): - def function(tensor): - return tensor[:, :, :4], tensor[:, :, 4:] - - return tf.keras.layers.Lambda(function)(tensor) - def call(self, input_tensors, training=False): - ll = [self.split_low_high(tensor)[0] for tensor in input_tensors] + ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors] ll = self.concat(ll) feats = self.lbn_layer(ll) feats = self.batch_norm(feats)