Skip to content
Snippets Groups Projects
Commit c552ba8a authored by ssibirtsev's avatar ssibirtsev
Browse files

Update train_droplet.py

parent 731c5e4f
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,8 @@ cluster = False ...@@ -29,6 +29,8 @@ cluster = False
### please specify only for non-cluster executions ### please specify only for non-cluster executions
# file format of images
file_format = "jpg"
# input dataset path to find in "...\datasets\input\..." # input dataset path to find in "...\datasets\input\..."
dataset_path = r"test" dataset_path = r"test"
# path to save the new weights "...\models\... # path to save the new weights "...\models\...
...@@ -142,6 +144,7 @@ if cluster == False: ...@@ -142,6 +144,7 @@ if cluster == False:
ROOT_DIR = os.path.abspath("") ROOT_DIR = os.path.abspath("")
WEIGHTS_DIR = os.path.join(ROOT_DIR, "models", new_weights_path) WEIGHTS_DIR = os.path.join(ROOT_DIR, "models", new_weights_path)
EXCEL_DIR = os.path.join(WEIGHTS_DIR, name_result_file + '.xlsx') EXCEL_DIR = os.path.join(WEIGHTS_DIR, name_result_file + '.xlsx')
FILE_FORMAT = file_format
BASE_WEIGHTS = base_weights BASE_WEIGHTS = base_weights
IMAGE_MAX = image_max IMAGE_MAX = image_max
EARLY = early_stopping EARLY = early_stopping
...@@ -190,6 +193,8 @@ else: ...@@ -190,6 +193,8 @@ else:
parser.add_argument('--base_weights', required=False, default='coco', parser.add_argument('--base_weights', required=False, default='coco',
metavar="/path/to/weights.h5", metavar="/path/to/weights.h5",
help="Path to weights .h5 file or 'coco'") help="Path to weights .h5 file or 'coco'")
parser.add_argument('--file_format', required=True,
help='')
parser.add_argument('--masks', required=False, type=str, parser.add_argument('--masks', required=False, type=str,
default="False", default="False",
help='Generate detection masks? True = yes, False = no') help='Generate detection masks? True = yes, False = no')
...@@ -283,6 +288,7 @@ else: ...@@ -283,6 +288,7 @@ else:
save_dir = "/rwthfs/rz/cluster/hpcwork/ss002458/" save_dir = "/rwthfs/rz/cluster/hpcwork/ss002458/"
WEIGHTS_DIR = os.path.join(save_dir, "models", args.new_weights_path + timestr + str(random.randint(1000,9999))) WEIGHTS_DIR = os.path.join(save_dir, "models", args.new_weights_path + timestr + str(random.randint(1000,9999)))
EXCEL_DIR = os.path.join(WEIGHTS_DIR, args.name_result_file + '.xlsx') EXCEL_DIR = os.path.join(WEIGHTS_DIR, args.name_result_file + '.xlsx')
FILE_FORMAT = args.file_format
BASE_WEIGHTS = args.base_weights BASE_WEIGHTS = args.base_weights
# #
if args.masks == "True": if args.masks == "True":
...@@ -346,10 +352,11 @@ dataset_size = 0 ...@@ -346,10 +352,11 @@ dataset_size = 0
images_mean_pixel = [] images_mean_pixel = []
for f in sorted(os.listdir(DATASET_DIR)): for f in sorted(os.listdir(DATASET_DIR)):
sub_folder = os.path.join(DATASET_DIR, f) sub_folder = os.path.join(DATASET_DIR, f)
dataset_size = dataset_size + len(glob.glob1(sub_folder, "*.jpg")) dataset_size = dataset_size + len(glob.glob1(sub_folder, FILE_FORMAT))
images_path = glob.glob(sub_folder + "/*.jpg") images_path = glob.glob(sub_folder + FILE_FORMAT)
for img_path in images_path: for img_path in images_path:
img = cv2.imread(img_path) img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if CONTRAST ==1: if CONTRAST ==1:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = exposure.equalize_adapthist(img) img = exposure.equalize_adapthist(img)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment