diff --git a/classes/droplet/train_droplet.py b/classes/droplet/train_droplet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6b1c9a327ecbe547f7453b3a8a26d577db4278
--- /dev/null
+++ b/classes/droplet/train_droplet.py
@@ -0,0 +1,917 @@
+"""
+MRCNN Particle Detection
+Train the droplet class of the MRCNN model.
+
+The source code of "MRCNN Particle Detection" (https://git.rwth-aachen.de/avt-fvt/private/mrcnn-particle-detection) 
+is based on the source code of "Mask R-CNN" (https://github.com/matterport/Mask_RCNN).
+
+The source code of "Mask R-CNN" is licensed under the MIT License (MIT).
+Copyright (c) 2017 Matterport, Inc.
+Written by Waleed Abdulla
+
+All source code modifications to the source code of "Mask R-CNN" in "MRCNN Particle Detection" 
+are licensed under the Eclipse Public License v2.0 (EPL 2.0).
+Copyright (c) 2022-2023 Fluid Process Engineering (AVT.FVT), RWTH Aachen University
+Edited by Stepan Sibirtsev, Mathias Neufang & Jakob Seiler
+
+The coyprights and license terms are given in LICENSE.
+
+Ideas and a small code snippets were adapted from these sources:
+https://github.com/mat02/Mask_RCNN
+"""  
+
+### ------------------------- ###
+### input training parameters ###
+### ------------------------- ###
+
+# is the script executed on the cluster? 
+# e.g., RWTH High Performance Computing cluster? 
+# True=yes, False=no
+cluster=False
+
+### please specify only for non-cluster executions 
+
+# file format of images
+file_format="jpg"
+# input dataset folder located in path: "...\datasets\input\..."
+dataset_path="test_input"                  
+# output weights folder located in path: "...\models\..."
+# weights of the individual epochs=MRCNN models
+new_weights_path="weights"
+# name of the excel output file located in path: "...\models\<WeightsFolderName>\"
+# this excel file shows the distribution of the input images to the folds.
+name_result_file="folds"
+# generate detection masks? 
+# True=yes, False=no
+masks=False
+# use GPU or CPU? 
+# True=GPU, False=CPU
+device=True
+# epochs to train
+epochs=50
+# should early stopping be used? 
+# 0=no, otherwise value is number of epochs without improvement
+early_stopping=0
+# loss monitored by early stopping
+early_loss="val_loss"
+# base weights the training starts from
+base_weights="coco"
+# percentage of the training dataset to be used for training [%], 
+# e.g., to determine minimum required number of images in training/validation set 
+# for accurate detection performance
+dataset_quantity=100
+
+### specifications for Weights & Biases
+
+# use Weights & Biases to track training data? 
+# True=yes, False=no
+use_wandb=False
+# enter W&B entity name 
+# (check projects in Weights & Biases, entity name is next to the project name)
+wandb_entity="test"
+# enter W&B project name
+wandb_project="test"
+# enter group name within the W&B project
+wandb_group="test"
+# enter run name within the group of W&B project
+wandb_run="test"
+
+### specifications for k-fold cross-validation
+
+# perform a k-fold cross-validation? 
+# if you want to train the final models, choose False.
+# True=yes, False=no
+cross_validation=True
+# number of folds for k-fold cross-validation
+k_fold=5
+# validation fold. starting with 0. the remaining folds are training folds
+k_fold_val=0
+
+### specifications for training parameters
+
+# backbone (see BACKBONE parameter in config.py). 
+# 0="resnet50", 1="resnet101"
+backbone_type=0
+# which layers should be trained? 
+# True=train all layers, False=train only heads
+train_all_layers=True
+# number of images used to train the model on each GPU.
+# if only one GPU is used, this parameter is equivalent to batch size 
+# (see BATCH_SIZE parameter in config.py).
+# a 12GB GPU can typically handle 2 images of 1024x1024px resolution.
+# adjust this parameter based on your GPU memory and image resolution. 
+images_gpu=1
+# learning rate (see LEARNING_RATE parameter in config.py). 
+# 0=0.01, 1=0.001, 2=0.0001
+learning=1
+# image resolution (see IMAGE_MAX_DIM parameter in config.py). 
+# select the closest value corresponding to the largest side of the image.
+# 0=512, 1=1024, 2=2048
+image_max=1
+# learning momentum (see LEARNING_MOMENTUM parameter in config.py). 
+# 0=0.8, 1=0.9, 2=0.99
+momentum=1
+# weight decay (see WEIGHT_DECAY parameter in config.py). 
+# 0=0.0001, 1=0.001, 2=0.01
+w_decay=0
+
+### specifications for augmentation methods
+
+# use augmentation methods? 
+# True=yes, False=no
+augmentation=False
+# use augmentation method flip? 
+# 0=no, 1=(0.5, 0.5)
+flip=0
+# use augmentation method crop? 
+# 0=no, 1=(-0.25, 0), 2=(-0.1, 0)
+cropandpad=0
+# use augmentation method rotate? 
+# 0=no, 1=(-45, 45), 2=(-90, 90)
+rotate=0
+# use augmentation method additive Gaussian noise? 
+# 0=no, 1=0.01, 2=0.02
+noise=0 
+# use augmentation method gamma contrast? 
+# 0=no, 1=yes
+gamma=0 
+
+### specifications for contrast adjustment
+
+# use contrast adjustment? 
+# 0=no, 1=contrast limited adaptive histogram equalization, 2=contrast stretching  
+contrast=0
+
+### ----------------------------------- ###
+###             Initialization          ###
+### ----------------------------------- ###
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+import os
+import sys
+import json
+import datetime
+import time
+import numpy as np
+import skimage.draw
+import tensorflow as tf
+import random
+import pandas as pd
+from numpy import array
+from numpy import asarray
+from pathlib import Path
+from skimage import exposure
+import cv2
+import glob
+
+# Root directory of the project
+if cluster == False:
+    ROOT_DIR = os.path.abspath("")
+    WEIGHTS_DIR = os.path.join(ROOT_DIR, "models", new_weights_path)
+    EXCEL_DIR = os.path.join(WEIGHTS_DIR, name_result_file + '.xlsx')
+    FILE_FORMAT = file_format
+    BASE_WEIGHTS = base_weights
+    IMAGE_MAX = image_max
+    EARLY = early_stopping
+    EARLY_LOSS = early_loss
+    EPOCH_NUMBER = epochs
+    DATASET_QUANTITY = dataset_quantity
+    K_FOLD = k_fold
+    K_FOLD_VAL = k_fold_val
+    DEVICE = device
+    IMAGES_GPU = images_gpu
+    MASKS = masks
+    AUGMENTATION = augmentation
+    CONTRAST = contrast
+    USE_WANDB = use_wandb
+    BACKBONE_TYPE = backbone_type
+    CROSS_VALIDATION = cross_validation
+    WANDB_ENTITY = wandb_entity
+    WANDB_PROJECT = wandb_project
+    WANDB_GROUP = wandb_group
+    WANDB_RUN = wandb_run
+    LEARNING = learning
+    MOMENTUM = momentum
+    W_DECAY = w_decay
+    TRAIN_ALL_LAYERS = train_all_layers
+    AUG_PARAMETERS = (cropandpad, rotate, noise, gamma, flip)
+    if CONTRAST == 0:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", dataset_path, "original")
+    elif CONTRAST == 1:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", dataset_path, "clahe")
+    elif CONTRAST == 2:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", dataset_path, "stretching")
+
+else:
+    import argparse
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(
+        description='Train Mask R-CNN to detect droplets.')
+    parser.add_argument('--dataset_path', required=True, 
+                        help='Directory of the Droplet dataset')
+    parser.add_argument('--name_result_file', required=False, default='folds',
+                        metavar="/path/to/droplet/dataset/",
+                        help='Name of the excel result file to find in "Mask_R_CNN\models\<WeightsFolderName>\"')    
+    parser.add_argument('--new_weights_path', required=False, default='weights',
+                        metavar="/path/to/logs/",
+                        help='Logs and checkpoints directory (default=logs/)')
+    parser.add_argument('--base_weights', required=False, default='coco',
+                        metavar="/path/to/weights.h5",
+                        help="Path to weights .h5 file or 'coco'")
+    parser.add_argument('--file_format', required=True,
+                        help='')
+    parser.add_argument('--masks', required=False, type=str,
+                        default="False",
+                        help='Generate detection masks? True = yes, False = no')
+    parser.add_argument('--device', required=True, type=str,
+                        help='is the evaluation done on GPU? True = yes, False = no')
+    parser.add_argument('--augmentation', required=False, type=str,
+                        default="False",
+                        help='image augmentation of dataset')
+    parser.add_argument('--use_wandb', required=False, type=str,
+                        default="False",
+                        help='use wandb for data collection')
+    parser.add_argument('--cross_validation', required=False, type=str,
+                        default="False",
+                        help='trains model on all data, disables train/validation split')
+    parser.add_argument('--train_all_layers', required=False, type=str,
+                        default="True",
+                        help='')    
+    #
+    parser.add_argument('--early_loss', required=False,
+                        default="val_loss",
+                        help='monitored early stopping quantity')
+    parser.add_argument('--wandb_entity', required=False,
+                        default="test_entity",
+                        help='')    
+    parser.add_argument('--wandb_project', required=False,
+                        default="test_project",
+                        help='')
+    parser.add_argument('--wandb_group', required=False,
+                        default="test_group",
+                        help='')
+    parser.add_argument('--wandb_run', required=False,
+                        default="test_run",
+                        help='')
+    #
+    parser.add_argument('--image_max', required=True, type=int,
+                        help="max. image size")                      
+    parser.add_argument('--images_gpu', required=True, type=int,
+                        help='Number of images to train with on each GPU')
+    parser.add_argument('--early_stopping', required=False, type=int,
+                        default=0,
+                        help='enables early stopping')
+    parser.add_argument('--epochs', required=False, type=int,
+                        default=50,
+                        help='set number of training epochs, default = 15')
+    parser.add_argument('--dataset_quantity', required=False, type=int,
+                        default=100,
+                        help='ratio of train/validation dataset in [%], default = 100')
+    parser.add_argument('--k_fold', required=False, type=int,
+                        default=5,
+                        help='# number of folds for k-fold cross validation')       
+    parser.add_argument('--k_fold_val', required=False, type=int,
+                        default=0,
+                        help='fold of k fold validation set')            
+    parser.add_argument('--contrast', required=False, type=int,
+                        default=0,
+                        help='Contrast adjustment? 0 = no, 1 = contrast limited adaptive histogramm equalization, 2 = contrast stretching ')
+    parser.add_argument('--backbone_type', required=False, type=int,
+                        default=0,
+                        help='"resnet101" or "resnet50"')
+    parser.add_argument('--learning', required=False, type=int,
+                        default=1,
+                        help='')
+    parser.add_argument('--momentum', required=False, type=int,
+                        default=1,
+                        help='')
+    parser.add_argument('--w_decay', required=False, type=int,
+                        default=0,
+                        help='')
+                    
+    # Augmentations
+    parser.add_argument('--flip', required=False, default="0")
+    parser.add_argument('--cropandpad', required=False, default="0")
+    parser.add_argument('--rotate', required=False, default="0")
+    parser.add_argument('--noise', required=False, default="0")
+    parser.add_argument('--gamma', required=False, default="0")
+    
+    args = parser.parse_args()
+
+    timestr = time.strftime("%H")   
+    ROOT_DIR = os.path.join("/rwthfs/rz/cluster", os.path.abspath("../.."))
+    save_dir = "/rwthfs/rz/cluster/hpcwork/ss002458/"
+    WEIGHTS_DIR = os.path.join(save_dir, "models", args.new_weights_path + timestr + str(random.randint(1000,9999)))
+    EXCEL_DIR = os.path.join(WEIGHTS_DIR, args.name_result_file + '.xlsx')    
+    FILE_FORMAT = args.file_format    
+    BASE_WEIGHTS = args.base_weights
+    #
+    if args.masks == "True":
+        MASKS = True
+    elif args.masks == "False":
+        MASKS = False
+    if args.device == "True":
+        DEVICE = True
+    elif args.device == "False":
+        DEVICE = False
+    if args.augmentation == "True":
+        AUGMENTATION = True
+    elif args.augmentation == "False":
+        AUGMENTATION = False
+    if args.use_wandb == "True":
+        USE_WANDB = True
+    elif args.use_wandb == "False":
+        USE_WANDB = False
+    if args.cross_validation == "True":
+        CROSS_VALIDATION = True
+    elif args.cross_validation == "False":
+        CROSS_VALIDATION = False
+    if args.train_all_layers == "True":
+        TRAIN_ALL_LAYERS = True
+    elif args.train_all_layers == "False":
+        TRAIN_ALL_LAYERS = False       
+    #
+    IMAGE_MAX = args.image_max
+    EARLY = args.early_stopping
+    EARLY_LOSS = args.early_loss
+    EPOCH_NUMBER = args.epochs
+    DATASET_QUANTITY = args.dataset_quantity
+    K_FOLD = args.k_fold
+    K_FOLD_VAL = args.k_fold_val
+    DEVICE = args.device
+    IMAGES_GPU = args.images_gpu
+    MASKS = args.masks
+    AUGMENTATION = args.augmentation
+    CONTRAST = args.contrast
+    AUG_PARAMETERS = (args.cropandpad, args.rotate, args.noise, args.gamma, args.flip)
+    CROSS_VALIDATION = args.cross_validation
+    WANDB_ENTITY = args.wandb_entity
+    WANDB_PROJECT = args.wandb_project
+    WANDB_GROUP = args.wandb_group
+    WANDB_RUN = args.wandb_run
+    if CONTRAST == 0:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", args.dataset_path, "original")
+    elif CONTRAST == 1:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", args.dataset_path, "clahe")
+    elif CONTRAST == 2:
+        DATASET_DIR = os.path.join(ROOT_DIR, "datasets/input", args.dataset_path, "stretching")
+    USE_WANDB = args.use_wandb
+    BACKBONE_TYPE = args.backbone_type
+    LEARNING = args.learning
+    MOMENTUM = args.momentum
+    W_DECAY = args.w_decay
+
+
+### Initialization, wie viele Bilder insgesamt (training + validierung)
+dataset_size = 0
+images_mean_pixel = []
+for f in sorted(os.listdir(DATASET_DIR)): 
+    sub_folder = os.path.join(DATASET_DIR, f)
+    dataset_size = dataset_size + len(glob.glob1(sub_folder, "*." + FILE_FORMAT))
+    images_path = glob.glob(sub_folder + "/*." + FILE_FORMAT)
+    for img_path in images_path:
+        img = cv2.imread(img_path)
+        if CONTRAST ==1:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            img = exposure.equalize_adapthist(img)
+            img = img.astype('float32') * 255   
+        images_mean_pixel.append(img)
+color_sum=[0,0,0]
+for img in images_mean_pixel:
+    pixels = asarray(img)
+    pixels = pixels.astype('float32')
+    # calculate per-channel means and standard deviations
+    means = pixels.mean(axis=(0, 1), dtype='float64')
+    color_sum += means
+    mean_pixel = color_sum/len(images_mean_pixel)
+
+COMMAND_MODE = "train"
+# Import Mask RCNN
+sys.path.append(ROOT_DIR)  # To find local version of the library
+Path(WEIGHTS_DIR).mkdir(parents=True, exist_ok=True)
+from mrcnn.config import Config
+from mrcnn import model as modellib, utils
+
+# Path to trained weights file
+COCO_WEIGHTS_PATH = os.path.join(ROOT_DIR, "models/coco/mask_rcnn_coco.h5")
+
+# Directory to save logs and model checkpoints, if not provided
+# through the command line argument --logs
+# DEFAULT_LOGS_DIR = os.path.join(r'D:\logs', "models")
+
+############################################################
+#  Configurations
+############################################################
+
+
+class DropletConfig(Config):
+    """Configuration for training on the toy  dataset.
+    Derives from the base Config class and overrides some values.
+    """
+    # Give the configuration a recognizable name
+    NAME = "droplet"
+
+    # NUMBER OF GPUs to use. When using only a CPU, this needs to be set to 1.
+    GPU_COUNT = 1
+
+    # Generate detection masks
+    #     False: Output only bounding boxes like in Faster-RCNN
+    #     True: Generate masks as in Mask-RCNN
+    if MASKS == True:
+        GENERATE_MASKS = True
+    else: 
+        GENERATE_MASKS = False
+
+    # We use a GPU with 12GB memory, which can fit two images.
+    # Adjust down if you use a smaller GPU.
+    if DEVICE == True:
+        IMAGES_PER_GPU = IMAGES_GPU
+    else:
+        IMAGES_PER_GPU = 1
+
+    #
+    if BACKBONE_TYPE == 0:
+        BACKBONE = "resnet50"
+    else:   
+        BACKBONE = "resnet101"
+    #
+    if LEARNING == 0:
+        LEARNING_RATE = 0.01
+    elif LEARNING == 1:
+        LEARNING_RATE = 0.001
+    elif LEARNING == 2:
+        LEARNING_RATE = 0.0001        
+    #
+    if MOMENTUM == 0:
+        LEARNING_MOMENTUM = 0.8
+    elif MOMENTUM == 1:
+        LEARNING_MOMENTUM = 0.9
+    elif MOMENTUM == 2:
+        LEARNING_MOMENTUM = 0.99
+    #
+    if W_DECAY == 0:
+        WEIGHT_DECAY = 0.0001
+    elif W_DECAY == 1:
+        WEIGHT_DECAY = 0.001
+    elif W_DECAY == 2:
+        WEIGHT_DECAY = 0.01       
+
+    # Number of classes (including background)
+    NUM_CLASSES = 1 + 1  # Background + droplet
+
+    # Number of training steps per epoch
+    if CROSS_VALIDATION == True:
+        TRAINING_STEPS = round((DATASET_QUANTITY//100)*dataset_size*(K_FOLD-1)/K_FOLD)//(IMAGES_GPU*GPU_COUNT)
+        VALIDATION_STEPS = round((DATASET_QUANTITY//100)*dataset_size/K_FOLD)//(IMAGES_GPU*GPU_COUNT)
+    else:
+        TRAINING_STEPS = round((DATASET_QUANTITY//100)*dataset_size)//(IMAGES_GPU*GPU_COUNT)
+
+    # Input image resizing
+    if IMAGE_MAX == 0:
+        IMAGE_MAX_DIM = 512
+    elif IMAGE_MAX == 1:
+        IMAGE_MAX_DIM = 1024
+    elif IMAGE_MAX == 2:
+        IMAGE_MAX_DIM = 2048
+
+    IMAGE_MIN_DIM = IMAGE_MAX_DIM
+
+    MEAN_PIXEL = mean_pixel
+
+    # unterschiedlich
+    #CONTRAST = CONTRAST
+    #if CONTRAST == 0:
+    #    MEAN_PIXEL = np.array([120.4, 120.4, 120.4])
+    #elif CONTRAST == 1:
+    #    MEAN_PIXEL = np.array([112.3, 112.3, 112.3])
+    #elif CONTRAST == 2:
+    #    MEAN_PIXEL = np.array([148.9, 148.9, 148.9])
+
+
+############################################################
+#  Dataset
+############################################################
+
+class DropletDataset(utils.Dataset):
+
+    def load_droplet(self, dataset_dir, subset, model):
+        """Load a subset of the Droplet dataset.
+        dataset_dir: Root directory of the dataset.
+        subset: Subset to load: train or val
+        """
+        # Add classes. We have only one class to add.
+        self.add_class("droplet", 1, "droplet")
+
+        # define path of training/validation dataset
+        # dataset_dir = os.path.join(dataset_dir, "all")
+
+        # Load annotations
+        # VGG Image Annotator (up to version 1.6) saves each image in the form:
+        # { 'filename': '28503151_5b5b7ec140_b.jpg',
+        #   'regions': {
+        #       '0': {
+        #           'region_attributes': {},
+        #           'shape_attributes': {
+        #               'all_points_x': [...],
+        #               'all_points_y': [...],
+        #               'name': 'polygon'}},
+        #       ... more regions ...
+        #   },
+        #   'size': 100202
+        # }
+        # We mostly care about the x and y coordinates of each region
+        # Note: In VIA 2.0, regions was changed from a dict to a list.
+        annotations_all = []
+        annotations = []
+        for dataset_folder in sorted(os.listdir(dataset_dir)):
+            annotations_quality = json.load(open(os.path.join(dataset_dir, dataset_folder, "train.json")))
+            annotations_quality = list(annotations_quality.values()) # don't need the dict keys
+            # The VIA tool saves images in the JSON even if they don't have any
+            # annotations. Skip unannotated images.
+            annotations_quality = [a for a in annotations_quality if a['regions']]
+            if CROSS_VALIDATION == True:
+                ### random choice of the training/validation dataset from the existing dataset
+                # resetting the random seed to ensure comparability between runs
+                np.random.seed(23)
+                # define quantity of train/validation dataset
+                train_val_set = int(round(DATASET_QUANTITY*len(annotations_quality)/100))
+                # random choice of the training/validation dataset
+                annotations_quality = np.random.choice(annotations_quality, train_val_set, replace=False)
+                # split training/validation dataset in folds
+                annotations_quality = np.array_split(annotations_quality,K_FOLD)
+                # transponse list for further processing
+                annotations_quality = np.transpose(annotations_quality)
+                # merging the datasets of different qualities into one dataset
+                annotations_all.extend(annotations_quality)
+                # save the k-fold splitted training/validation dataset
+                pd.DataFrame(annotations_all).to_excel(EXCEL_DIR, header=True, index=False)
+                # go through columns of the k-fold splitted training/validation dataset
+                for column in range(K_FOLD): 
+                    annotations = [row[column] for row in annotations_quality]   
+                    # check if partial dataset is a train or validation dataset
+                    if subset == "train" and column != K_FOLD_VAL:
+                        annotations_use = annotations
+                    elif subset == "val" and column == K_FOLD_VAL:
+                        annotations_use = annotations
+                    else:
+                        continue
+                    # Add images
+                    for a in annotations_use:
+                        # Get the x, y coordinates of points of the polygons that make up
+                        # the outline of each object instance. These are stores in the
+                        # shape_attributes (see json format above)
+                        # The if condition is needed to support VIA versions 1.x and 2.x.
+                        if type(a['regions']) is dict:
+                            polygons = [r['shape_attributes'] for r in a['regions'].values()]
+                        else:
+                            polygons = [r['shape_attributes'] for r in a['regions']] 
+
+                        # load_mask() needs the image size to convert polygons to masks.
+                        # Unfortunately, VIA doesn't include it in JSON, so we must read
+                        # the image. This is only managable since the dataset is tiny.
+                        image_path = os.path.join(dataset_dir, dataset_folder, a['filename'])
+                        image = skimage.io.imread(image_path)  
+                        height, width = image.shape[:2]
+
+                        if type(a['regions']) is dict:
+                            polygons = [r['shape_attributes'] for r in a['regions'].values()]
+                        else:
+                            polygons = [r['shape_attributes'] for r in a['regions']] 
+                        
+                        self.add_image(
+                            "droplet",
+                            image_id=a['filename'],  # use file name as a unique image id
+                            path=image_path,
+                            width=width, height=height,
+                            polygons=polygons)
+            else:
+                ### random choice of the training/validation dataset from the existing dataset
+                # resetting the random seed to ensure comparability between runs
+                np.random.seed(23)
+                # define quantity of train/validation dataset
+                train_val_set = int(round(DATASET_QUANTITY*len(annotations_quality)/100))
+                # random choice of the training/validation dataset
+                annotations_quality = np.random.choice(annotations_quality, train_val_set, replace=False)
+                if subset == "train":
+                    for a in annotations_quality:
+                        image_path = os.path.join(dataset_dir, dataset_folder, a['filename'])
+                        image = skimage.io.imread(image_path)
+                        height, width = image.shape[:2]
+                        
+                        if type(a['regions']) is dict:
+                            polygons = [r['shape_attributes'] for r in a['regions'].values()]
+                        else:
+                            polygons = [r['shape_attributes'] for r in a['regions']] 
+                        
+                        self.add_image(
+                            "droplet",
+                            # use file name as a unique image id
+                            image_id=a['filename'],
+                            path=image_path,
+                            width=width, height=height,
+                            polygons=polygons)
+
+    def load_mask(self, image_id):
+        """Generate instance masks for an image.
+       Returns:
+        masks: A bool array of shape [height, width, instance count] with
+            one mask per instance.
+        class_ids: a 1D array of class IDs of the instance masks.
+        """
+        # If not a droplet dataset image, delegate to parent class.
+        image_info = self.image_info[image_id]
+        if image_info["source"] != "droplet":
+            return super(self.__class__, self).load_mask(image_id)
+
+        # Convert polygons to a bitmap mask of shape
+        # [height, width, instance_count]
+        info = self.image_info[image_id]
+        mask = np.zeros([info["height"], info["width"], len(info["polygons"])],
+                        dtype=np.uint8)
+       
+        for i, p in enumerate(info["polygons"]):
+            # Get indexes of pixels inside the polygon and set them to 1
+            if p['name']=='polygon':
+                rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])
+            
+            elif p['name']=='ellipse':
+                rr, cc = skimage.draw.ellipse(p['cy'], p['cx'], p['ry'], p['rx'])
+
+            else:
+                
+                rr, cc = skimage.draw.circle(p['cy'], p['cx'], p['r'])
+            
+            x = np.array((rr, cc)).T
+            d = np.array([i for i in x if (i[0] < info["height"] and i[0] > 0)])
+            e = np.array([i for i in d if (i[1] < info["width"] and i[1] > 0)])
+
+            rr = np.array([u[0] for u in e])
+            cc = np.array([u[1] for u in e])
+
+            if len(rr)==0 or len(cc)==0:
+                continue
+            mask[rr, cc, i] = 1    
+
+        # Return mask, and array of class IDs of each instance. Since we have
+        # one class ID only, we return an array of 1s
+        return mask.astype(np.bool), np.ones([mask.shape[-1]], dtype=np.int32)
+
+    def image_reference(self, image_id):
+        """Return the path of the image."""
+        info = self.image_info[image_id]
+        if info["source"] == "droplet":
+            return info["path"]
+        else:
+            super(self.__class__, self).image_reference(image_id)
+
+def get_augmentation():
+    import imgaug.augmenters as iaa
+    
+    # tables which translate input to values for evaluation
+    # simpler to pass 0-2 in array jobs than specific values
+    
+    cropDict = {
+        1: (-0.25, 0),
+        2: (-0.1, 0)
+    }
+    rotDict = {
+        1: (-45, 45),
+        2: (-90, 90)
+    }
+    noiseDict = {
+        1: 0.01,
+        2: 0.02
+    }
+    flipDict = {
+        1: (0.5, 0.5)
+    }
+    
+    if AUG_PARAMETERS[4] != 0:
+        flip_lr_up = flipDict[AUG_PARAMETERS[4]]
+        aug = iaa.Sequential([
+                iaa.Fliplr(flip_lr_up[0]),
+                iaa.Flipud(flip_lr_up[1])
+        # iaa.Sometimes(0.5, iaa.Rot90(1))
+        ])
+    else:
+        aug = iaa.Sequential([
+        ])
+    # add other augments that were enabled
+    if AUG_PARAMETERS[1] != 0:
+        rot = rotDict[AUG_PARAMETERS[1]]
+        randrot = iaa.Affine(rotate=(rot[0], rot[1]))
+        aug = iaa.Sequential([aug, randrot])
+
+    if AUG_PARAMETERS[0] != 0:
+        crop = cropDict[AUG_PARAMETERS[0]]
+        randcrop = iaa.CropAndPad(percent=(crop[0], crop[1]), sample_independently=False)
+        aug = iaa.Sequential([aug, randcrop])
+
+    if AUG_PARAMETERS[3] != 0:
+        #gamma = iaa.Sometimes(0.25*AUG_PARAMETERS[4], iaa.GammaContrast(gamma=(0.5, 2)))
+        gamma = AUG_PARAMETERS[3], iaa.GammaContrast(gamma=(0.5, 2))
+        aug = iaa.Sequential([aug, gamma])
+
+    if AUG_PARAMETERS[2] != 0:
+        noise = noiseDict[AUG_PARAMETERS[2]]
+        gaussnoise = iaa.AdditiveGaussianNoise(scale=noise*255)
+        aug = iaa.Sequential([aug, gaussnoise])
+    return aug
+
+def train(model, custom_callbacks=None):
+    """Train the model."""
+
+    # Training dataset.
+    dataset_train = DropletDataset()
+    dataset_train.load_droplet(DATASET_DIR, "train", model)
+    dataset_train.prepare()
+
+    # Validation dataset
+    if CROSS_VALIDATION == True:  
+        dataset_val = DropletDataset()
+        dataset_val.load_droplet(DATASET_DIR, "val", model)
+        dataset_val.prepare()
+    else:
+        dataset_val = None
+    print(f"Train Dataset: {len(dataset_train.image_ids)} Pictures")
+
+    # define augmentation
+    if AUGMENTATION == True:
+        augmentation_type = get_augmentation()
+    else: 
+        augmentation_type = None
+
+    # *** This training schedule is an example. Update to your needs ***
+    # Since we're using a very small dataset, and starting from
+    # COCO trained weights, we don't need to train too long. Also,
+    # no need to train all layers, just the heads should do it.
+    print("Training network layers")
+    if TRAIN_ALL_LAYERS == True:
+        layers = 'all'
+    else:
+        layers = 'heads'
+    model.train(dataset_train, dataset_val,
+                learning_rate=config.LEARNING_RATE,
+                epochs=EPOCH_NUMBER, augmentation=augmentation_type,
+                layers=layers, custom_callbacks=custom_callbacks)
+
+def color_splash(image, mask):
+    """Apply color splash effect.
+    image: RGB image [height, width, 3]
+    mask: instance segmentation mask [height, width, instance count]
+
+    Returns result image.
+    """
+    # Make a grayscale copy of the image. The grayscale copy still
+    # has 3 RGB channels, though.
+    gray = skimage.color.gray2rgb(skimage.color.rgb2gray(image)) * 255
+    # Copy color pixels from the original color image where mask is set
+    if mask.shape[-1] > 0:
+        # We're treating all instances as one, so collapse the mask into one layer
+        mask = (np.sum(mask, -1, keepdims=True) >= 1)
+        splash = np.where(mask, image, gray).astype(np.uint8)
+    else:
+        splash = gray.astype(np.uint8)
+    return splash
+
+
+def detect_and_color_splash(model, image_path=None, video_path=None):
+    assert image_path or video_path
+
+    # Image or video?
+    if image_path:
+        # Run model detection and generate the color splash effect
+        print("Running on {}".format(args.image))
+        # Read image
+        image = skimage.io.imread(args.image)
+        # Detect objects
+        r = model.detect([image], verbose=1)[0]
+        # Color splash
+        splash = color_splash(image, r['masks'])
+        # Save output
+        file_name = "splash_{:%Y%m%dT%H%M%S}.png".format(datetime.datetime.now())
+        skimage.io.imsave(file_name, splash)
+    elif video_path:
+        import cv2
+        # Video capture
+        vcapture = cv2.VideoCapture(video_path)
+        width = int(vcapture.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(vcapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        fps = vcapture.get(cv2.CAP_PROP_FPS)
+
+        # Define codec and create video writer
+        file_name = "splash_{:%Y%m%dT%H%M%S}.avi".format(datetime.datetime.now())
+        vwriter = cv2.VideoWriter(file_name,
+                                  cv2.VideoWriter_fourcc(*'MJPG'),
+                                  fps, (width, height))
+
+        count = 0
+        success = True
+        while success:
+            print("frame: ", count)
+            # Read next image
+            success, image = vcapture.read()
+            if success:
+                # OpenCV returns images as BGR, convert to RGB
+                image = image[..., ::-1]
+                # Detect objects
+                r = model.detect([image], verbose=0)[0]
+                # Color splash
+                splash = color_splash(image, r['masks'])
+                # RGB -> BGR to save image to video
+                splash = splash[..., ::-1]
+                # Add image to video writer
+                vwriter.write(splash)
+                count += 1
+        vwriter.release()
+    print("Saved to ", file_name)
+
+
+############################################################
+#  Training
+############################################################
+
+if __name__ == '__main__':
+    
+    # Validate arguments
+    if COMMAND_MODE == "train":
+        assert DATASET_DIR, "Argument --dataset is required for training"
+    elif COMMAND_MODE == "splash":
+        assert args.image or args.video,\
+               "Provide --image or --video to apply color splash"
+
+    print("Weights: ", BASE_WEIGHTS)
+    print("Dataset: ", DATASET_DIR)
+    print("Logs: ", WEIGHTS_DIR)
+
+    # Configurations
+    if COMMAND_MODE == "train":
+        config = DropletConfig()
+    else:
+        class InferenceConfig(DropletConfig):
+            # Set batch size to 1 since we'll be running inference on
+            # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
+            GPU_COUNT = 1
+            IMAGES_PER_GPU = 1
+        config = InferenceConfig()
+    config.display()
+
+    if USE_WANDB:
+        import wandb
+        from wandb.keras import WandbCallback
+        wandb.init(project=WANDB_PROJECT,
+                   entity=WANDB_ENTITY, config=config, group=WANDB_GROUP, name=WANDB_RUN)
+
+    # Create model
+    if COMMAND_MODE == "train":
+        model = modellib.MaskRCNN(mode="training", config=config,
+                                  model_dir=WEIGHTS_DIR, k_fold_val=K_FOLD_VAL)
+    else:
+        model = modellib.MaskRCNN(mode="inference", config=config,
+                                  model_dir=WEIGHTS_DIR, k_fold_val=K_FOLD_VAL)
+
+    # Select weights file to load
+    if BASE_WEIGHTS.lower() == "coco":
+        weights_path = COCO_WEIGHTS_PATH
+        # Download weights file
+        if not os.path.exists(weights_path):
+            utils.download_trained_weights(weights_path)
+    elif BASE_WEIGHTS.lower() == "last":
+        # Find last trained weights
+        weights_path = model.find_last()
+    elif BASE_WEIGHTS.lower() == "imagenet":
+        # Start from ImageNet trained weights
+        weights_path = model.get_imagenet_weights()
+    else:
+        weights_path = os.path.join(ROOT_DIR, "models", BASE_WEIGHTS + '.h5')
+
+    # Load weights
+    print("Loading weights ", weights_path)
+    if BASE_WEIGHTS.lower() == "coco":
+        # Exclude the last layers because they require a matching
+        # number of classes
+        model.load_weights(weights_path, by_name=True, exclude=[
+            "mrcnn_class_logits", "mrcnn_bbox_fc",
+            "mrcnn_bbox", "mrcnn_mask"])
+    else:
+        model.load_weights(weights_path, by_name=True)
+
+    custom_callbacks = []
+    
+    if USE_WANDB:
+        custom_callbacks.append(WandbCallback())
+
+    if EARLY:
+        custom_callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=EARLY_LOSS, patience=EARLY))
+        
+    # Train or evaluate
+    if COMMAND_MODE == "train":
+        train(model, custom_callbacks=custom_callbacks)
+    elif COMMAND_MODE == "splash":
+        detect_and_color_splash(model, image_path=args.image,
+                                video_path=args.video)
+    else:
+        print("'{}' is not recognized. "
+              "Use 'train' or 'splash'".format(COMMAND_MODE))