diff --git a/keras.py b/keras.py
index 77a0ab84cafada0d41f9ae1dc0ceae0b61cbf7b7..3c1c3dbea00038a076df678f1640e09c4c6be2bb 100644
--- a/keras.py
+++ b/keras.py
@@ -1016,6 +1016,16 @@ class ResNet(tf.keras.layers.Layer):
         return {"number_layers": self.number_layers}
 
 
+class RemoveLayer(tf.keras.layers.Layer):
+    def call(self, inputs):
+        return tf.keras.layers.Concatenate()([inputs[:, :, :4], inputs[:, :, 6:]])
+
+
+class SplitHighLow(tf.keras.layers.Layer):
+    def call(self, inputs):
+        return inputs[:, :, :4], inputs[:, :, 4:]
+
+
 class LBNLayer(tf.keras.layers.Layer):
     """
     Custom implementation of the LBNLayer with automatic cropping to
@@ -1064,14 +1074,8 @@ class LBNLayer(tf.keras.layers.Layer):
             # "pair_dy",
         ]
 
-    def split_low_high(self, tensor):
-        def function(tensor):
-            return tensor[:, :, :4], tensor[:, :, 4:]
-
-        return tf.keras.layers.Lambda(function)(tensor)
-
     def call(self, input_tensors, training=False):
-        ll = [self.split_low_high(tensor)[0] for tensor in input_tensors]
+        ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors]
         ll = self.concat(ll)
         feats = self.lbn_layer(ll)
         feats = self.batch_norm(feats)