diff --git a/mrcnn/parallel_model.py b/mrcnn/parallel_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ec9ef3d6f250e96d387dcaf539848842f563557
--- /dev/null
+++ b/mrcnn/parallel_model.py
@@ -0,0 +1,186 @@
+"""
+MRCNN Particle Detection
+Multi-GPU Support for Keras.
+
+The source code of "MRCNN Particle Detection" (https://git.rwth-aachen.de/avt-fvt/private/mrcnn-particle-detection) 
+is based on the source code of "Mask R-CNN" (https://github.com/matterport/Mask_RCNN).
+
+The source code of "Mask R-CNN" is licensed under the MIT License (MIT).
+Copyright (c) 2017 Matterport, Inc.
+Written by Waleed Abdulla
+
+All source code modifications to the source code of "Mask R-CNN" in "MRCNN Particle Detection" 
+are licensed under the Eclipse Public License v2.0 (EPL 2.0).
+Copyright (c) 2022-2023 Fluid Process Engineering (AVT.FVT), RWTH Aachen University
+Edited by Stepan Sibirtsev, Mathias Neufang & Jakob Seiler
+
+The copyrights and license terms are given in LICENSE.
+
+Ideas and small code snippets were adapted from these sources:
+https://github.com/mat02/Mask_RCNN
+https://github.com/fchollet/keras/issues/2436
+https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
+https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
+https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
+"""
+
+import tensorflow as tf
+import keras.backend as K
+import keras.layers as KL
+import keras.models as KM
+
+
+class ParallelModel(KM.Model):
+    """Subclasses the standard Keras Model and adds multi-GPU support.
+    It works by creating a copy of the model on each GPU, slicing the
+    inputs, and sending one slice to each copy. The outputs of all copies
+    are then merged back together so the loss is applied to the combined
+    outputs.
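+
+    A minimal usage sketch (illustrative; `build_model` stands in for any
+    function that returns a Keras model to wrap):
+
+        model = build_model()
+        if gpu_count > 1:
+            model = ParallelModel(model, gpu_count)
+        # compile and fit the wrapper like a regular Keras model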
+    """
+
+    def __init__(self, keras_model, gpu_count):
+        """Class constructor.
+        keras_model: The Keras model to parallelize
+        gpu_count: Number of GPUs. Must be > 1
+        """
+        self.inner_model = keras_model
+        self.gpu_count = gpu_count
+        merged_outputs = self.make_parallel()
+        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
+                                            outputs=merged_outputs)
+
+    def __getattribute__(self, attrname):
+        """Redirect loading and saving methods to the inner model. That's where
+        the weights are stored."""
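+        # e.g. wrapper calls such as save_weights() or load_weights()
+        # are forwarded to self.inner_model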
+        if 'load' in attrname or 'save' in attrname:
+            return getattr(self.inner_model, attrname)
+        return super(ParallelModel, self).__getattribute__(attrname)
+
+    def summary(self, *args, **kwargs):
+        """Override summary() to display summaries of both, the wrapper
+        and inner models."""
+        super(ParallelModel, self).summary(*args, **kwargs)
+        self.inner_model.summary(*args, **kwargs)
+
+    def make_parallel(self):
+        """Creates a new wrapper model that consists of multiple replicas of
+        the original model placed on different GPUs.
+        """
+        # Slice inputs on the CPU to avoid sending a copy of the full
+        # inputs to all GPUs. Saves on bandwidth and memory.
+        input_slices = {name: tf.split(x, self.gpu_count)
+                        for name, x in zip(self.inner_model.input_names,
+                                           self.inner_model.inputs)}
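+        # For example, with gpu_count=2 an input tensor of shape
+        # (8, h, w, c) is split into two slices of shape (4, h, w, c).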
+
+        output_names = self.inner_model.output_names
+        outputs_all = [[] for _ in self.inner_model.outputs]
+
+        # Run the model call() on each GPU to place the ops there
+        for i in range(self.gpu_count):
+            with tf.device('/gpu:%d' % i):
+                with tf.name_scope('tower_%d' % i):
+                    # Run a slice of inputs through this replica
+                    zipped_inputs = zip(self.inner_model.input_names,
+                                        self.inner_model.inputs)
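+                    # Each Lambda below is applied to its tensor right away,
+                    # so the slicing function is evaluated with the current
+                    # `name` and tower index `i` during graph construction.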
+                    inputs = [
+                        KL.Lambda(lambda s: input_slices[name][i],
+                                  output_shape=lambda s: (None,) + s[1:])(tensor)
+                        for name, tensor in zipped_inputs]
+                    # Create the model replica and get the outputs
+                    outputs = self.inner_model(inputs)
+                    if not isinstance(outputs, list):
+                        outputs = [outputs]
+                    # Save the outputs for merging back together later
+                    for l, o in enumerate(outputs):
+                        outputs_all[l].append(o)
+
+        # Merge outputs on CPU
+        with tf.device('/cpu:0'):
+            merged = []
+            for outputs, name in zip(outputs_all, output_names):
+                # Concatenate or average outputs?
+                # Outputs usually have a batch dimension and we concatenate
+                # across it. If they don't, then the output is likely a loss
+                # or a metric value that gets averaged across the batch.
+                # Keras expects losses and metrics to be scalars.
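+                # E.g. with two towers, scalar losses l0 and l1 are merged
+                # to (l0 + l1) / 2, while tensors with a batch dimension
+                # are concatenated along axis 0.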
+                if K.int_shape(outputs[0]) == ():
+                    # Average
+                    m = KL.Lambda(lambda o: tf.add_n(o) / len(outputs),
+                                  name=name)(outputs)
+                else:
+                    # Concatenate
+                    m = KL.Concatenate(axis=0, name=name)(outputs)
+                merged.append(m)
+        return merged
+
+
+if __name__ == "__main__":
+    # Testing code below. It creates a simple model to train on MNIST and
+    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
+    # in TensorBoard. Run it as:
+    #
+    # python3 parallel_model.py
+
+    import os
+    import numpy as np
+    import keras.optimizers
+    from keras.datasets import mnist
+    from keras.preprocessing.image import ImageDataGenerator
+
+    GPU_COUNT = 2
+
+    # Root directory of the project
+    ROOT_DIR = os.path.abspath("../")
+
+    # Directory to save logs and trained model
+    MODEL_DIR = os.path.join(ROOT_DIR, "logs")
+
+    def build_model(x_train, num_classes):
+        # Reset default graph. Keras leaves old ops in the graph,
+        # which are ignored for execution but clutter graph
+        # visualization in TensorBoard.
+        tf.reset_default_graph()
+
+        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
+        x = KL.Conv2D(32, (3, 3), activation='relu', padding="same",
+                      name="conv1")(inputs)
+        x = KL.Conv2D(64, (3, 3), activation='relu', padding="same",
+                      name="conv2")(x)
+        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
+        x = KL.Flatten(name="flat1")(x)
+        x = KL.Dense(128, activation='relu', name="dense1")(x)
+        x = KL.Dense(num_classes, activation='softmax', name="dense2")(x)
+
+        return KM.Model(inputs, x, name="digit_classifier_model")
+
+    # Load MNIST Data
+    (x_train, y_train), (x_test, y_test) = mnist.load_data()
+    x_train = np.expand_dims(x_train, -1).astype('float32') / 255
+    x_test = np.expand_dims(x_test, -1).astype('float32') / 255
+
+    print('x_train shape:', x_train.shape)
+    print('x_test shape:', x_test.shape)
+
+    # Build data generator and model
+    datagen = ImageDataGenerator()
+    model = build_model(x_train, 10)
+
+    # Add multi-GPU support.
+    model = ParallelModel(model, GPU_COUNT)
+
+    optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)
+
+    model.compile(loss='sparse_categorical_crossentropy',
+                  optimizer=optimizer, metrics=['accuracy'])
+
+    model.summary()
+
+    # Train
+    model.fit_generator(
+        datagen.flow(x_train, y_train, batch_size=64),
+        steps_per_epoch=50, epochs=10, verbose=1,
+        validation_data=(x_test, y_test),
+        callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
+                                               write_graph=True)]
+    )