Skip to content
Snippets Groups Projects
check_user_input.py 9.61 KiB
Newer Older
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pandas as pd

from settings import *
from utilities.initialise_log import save_log


class check_general_settings():
    """Validate the user settings imported via ``from settings import *``.

    All checks run when the class is instantiated. Findings are written to a
    ``check_user_input.log`` file through ``save_log``. ``self.error`` is set
    to True if any check fails; warnings do not set it.

    Attributes:
        logger: Logger returned by ``save_log``. Its handlers are closed and
            removed once all checks have run (or one of them raised).
        error (bool): True if at least one check failed.
        dic (dict): Setting values grouped by expected type
            ('bool', 'path', 'str', 'int', 'int_float', 'list').
        dic_steps (dict): Setting values grouped by processing step.
    """

    # File name of the log written by this checker.
    _LOG_NAME = 'check_user_input.log'

    def __init__(self):

        save_path = self._resolve_log_path()

        # Start every run with a fresh log file.
        if os.path.exists(save_path):
            os.remove(save_path)

        self.logger = save_log(save_path)
        self.logger.info("Start checking user input")

        self.error = False

        try:
            self.set_up_dic()
            self.check_bools()
            self.check_list()
            self.check_int()
            self.check_int_float()
            self.check_string()
            self.check_path()
            self.check_bb()

            self.check_if_right_params_are_set()
            self.check_extension()
            self.check_path_extension_geosummary()
        finally:
            # Release the log file even if a check raised. Iterate over a copy:
            # removeHandler() mutates logger.handlers, and removing from the
            # list being iterated would skip every second handler.
            for handler in list(self.logger.handlers):
                handler.close()
                self.logger.removeHandler(handler)

    def _resolve_log_path(self):
        """Return the path of the log file for this run.

        The log goes next to the training dataset when a training dataset or
        a map is generated, otherwise next to the prediction dataset. If the
        configured path is a directory, the log is placed inside it. If it is
        a file path, the log is placed in that file's directory. When no usable
        path is configured (e.g. no purpose selected), the log is written to
        the current working directory rather than the filesystem root, which
        is usually not writable and would prevent any error from being logged.
        """
        if training_dataset or map_generation:
            base = path_train
        elif prediction_dataset:
            base = path_pred
        else:
            base = None

        if isinstance(base, str) and base:
            if os.path.isdir(base):
                return os.path.join(base, self._LOG_NAME)
            return os.path.join(os.path.dirname(base), self._LOG_NAME)
        return os.path.join(os.getcwd(), self._LOG_NAME)

    @staticmethod
    def _any_missing(values):
        """Return True if any entry of *values* is None."""
        # Identity test instead of `None in values`: `in` compares with ==,
        # which is ambiguous for array-like settings.
        return any(value is None for value in values)

    def check_if_right_params_are_set(self):
        """Check that the parameters needed by the selected run purpose(s) are set.

        Logs an error and sets ``self.error`` in any of these cases:

        - no purpose is selected;
        - a general setting (crs, no_value, random_seed, resolution) is None;
        - a parameter required for training dataset generation, prediction
          dataset generation or map generation is None.
        """
        # Flags that are all False select no purpose, just like flags that
        # are all None.
        if not (training_dataset or prediction_dataset or map_generation):
            self.logger.error('Specify a purpose of the run! Set either training_dataset, prediction_dataset and/or map_generation')
            self.error = True

        if self._any_missing([crs, no_value, random_seed, resolution]):
            self.logger.error('Set the general settings crs, no_value, random_seed and resolution!')
            self.error = True

        if training_dataset:
            if train_from_scratch is None and train_delete is None:
                self.logger.error('Specify whether you want to generate training dataset from scratch or add/remove feature(s)')
                self.error = True
            elif self._any_missing([
                    preprocessing, data_summary_path, key_to_include_path,
                    path_train, path_landslide_database, ID,
                    landslide_database_x, landslide_database_y,
                    path_nonls_locations, num_nonls,
                    nonls_database_x, nonls_database_y]):
                self.logger.error('Specify all necessary parameters for training dataset generation!')
                self.error = True

        if prediction_dataset:
            if pred_from_scratch is None and pred_delete is None:
                self.logger.error('Specify whether you want to generate prediction dataset from scratch or add/remove feature(s)')
                self.error = True
            elif self._any_missing([data_summary_path, key_to_include_path,
                                    bounding_box, path_pred]):
                self.logger.error('Specify all necessary parameters for prediction dataset generation!')
                self.error = True

        if map_generation:
            if self._any_missing([
                    path_ml, size, not_included_pred_data,
                    not_included_train_data, num_trees, criterion, depth,
                    model_to_save, model_to_load, model_database_dir,
                    parallel]):
                self.logger.error('Specify all necessary parameters for map generation!')
                self.error = True

    def set_up_dic(self):
        """Group the setting values by the type they are expected to have."""

        self.dic = {}
        self.dic['bool'] = [training_dataset, train_from_scratch, train_delete, prediction_dataset, pred_from_scratch, pred_delete, map_generation, parallel]
        self.dic['path'] = [path_ml, data_summary_path, key_to_include_path, path_train, path_landslide_database, path_nonls_locations, path_pred, model_database_dir]
        self.dic['str'] = [crs, ID, landslide_database_x, landslide_database_y, nonls_database_x, nonls_database_y, criterion, model_to_save, model_to_load]
        self.dic['int'] = [resolution, random_seed, num_nonls, num_trees, depth]
        self.dic['int_float'] = [size, no_value]
        self.dic['list'] = [bounding_box, not_included_pred_data, not_included_train_data]

        self.dic_steps = {}
        self.dic_steps['general'] = []
        self.dic_steps['run_purpose'] = [training_dataset, prediction_dataset, map_generation]

    def check_extension(self):
        """Check the file extensions of the dataset paths.

        CSV is required for:

        - the data summary;
        - the keys-to-include file;
        - the landslide database;
        - the training dataset.

        'nc' is required for the prediction dataset and the non-landslide
        locations. A path without a full stop is treated as a directory and
        only produces a warning.
        """
        self._check_file_extension(
            [data_summary_path, key_to_include_path,
             path_landslide_database, path_train], 'csv')
        self._check_file_extension([path_pred, path_nonls_locations], 'nc')

    def _check_file_extension(self, paths, extension):
        """Log an error for each path in *paths* not ending in ``.<extension>``."""
        for path in paths:
            # None means "not set"; non-string values are reported by check_path.
            if not isinstance(path, str):
                continue
            if '.' not in path:
                self.logger.warning(path + ': Directory is given. Generic file name will be used')
            elif path.count('.') != 1:
                # The only full stop allowed is the one before the extension.
                self.logger.error(path + ': Paths must not contain full stops!')
                self.error = True
            elif path.split('.')[1] != extension:
                self.logger.error(path + ': wrong file format! Needs to be ' + extension)
                self.error = True

    def _check_type(self, category, accepted_types, description):
        """Log an error for each non-None value of ``self.dic[category]``
        whose exact type is not in *accepted_types*.

        Uses ``type()`` rather than ``isinstance`` so that e.g. a bool is not
        accepted where an int is expected.
        """
        for value in self.dic[category]:
            if value is not None and type(value) not in accepted_types:
                # %r instead of `value + ...`: the offending value is by
                # definition not necessarily a string.
                self.logger.error('%r: %s' % (value, description))
                self.error = True

    def check_bools(self):
        """Check that all boolean settings are bools."""
        self.logger.info("Start checking bools")
        self._check_type('bool', (bool,), 'not a bool')

    def check_list(self):
        """Check that all list settings are lists."""
        self.logger.info("Start checking list")
        self._check_type('list', (list,), 'not a list')

    def check_int(self):
        """Check that all integer settings are ints."""
        self.logger.info("Start checking integers")
        self._check_type('int', (int,), 'not an integer')

    def check_int_float(self):
        """Check that all numeric settings are ints or floats."""
        self.logger.info("Start checking integers and floats")
        self._check_type('int_float', (int, float), 'not an integer or float')

    def check_string(self):
        """Check that all string settings are strings."""
        self.logger.info("Start checking strings")
        self._check_type('str', (str,), 'not a string')

    def check_path(self):
        """Check that every configured path is a string and exists.

        ``path_train`` is exempt while ``training_dataset`` is True, and
        ``path_pred`` while ``prediction_dataset`` is True. Presumably these
        are the outputs of the run and need not exist yet.
        """
        self.logger.info("Start checking paths")
        for path in self.dic['path']:
            if path is None:
                continue
            self.logger.info(path)
            if type(path) is not str:
                self.logger.error('%r: path is not a string' % (path,))
                self.error = True
            elif path == path_train and training_dataset is True:
                pass
            elif path == path_pred and prediction_dataset is True:
                pass
            elif not os.path.exists(path):
                self.logger.error(path + ': path could not be found!')
                self.error = True

    def check_bb(self):
        """Check the bounding box, ordered as [north, south, west, east].

        Logs an error and sets ``self.error`` in any of these cases:

        - the box does not have four entries;
        - south is not south of north;
        - west is not west of east.

        The exception is the west/east case where west > 170 and east < -170.
        Such a box crosses the antimeridian and may be intended, so it only
        triggers a warning.
        """
        if bounding_box is None:
            return

        try:
            north, south, west, east = bounding_box
        except (TypeError, ValueError):
            self.logger.error('Bounding box must contain exactly four coordinates!')
            self.error = True
            return

        if south >= north:
            self.logger.error('Careful! South coordinate north of north coordinate!')
            self.error = True

        if west >= east:
            if west > 170 and east < -170:
                self.logger.warning('Careful! Please check east and west coordinates!')
            else:
                self.logger.error('Careful! West coordinate east of east coordinate!')
                self.error = True

    def check_path_extension_geosummary(self):
        """Check the data files listed in the geospatial data summary.

        The check runs only if the data summary and the keys-to-include file
        are both existing CSV files. For each key in the 'keys_to_include'
        column, the first row of the summary with that value in 'keys' is
        looked up. Its 'path' must exist and end in .nc, .tif or .tiff;
        otherwise an error is logged and ``self.error`` is set.
        """
        self.logger.info('Start checking paths in geospatial data summary')

        if not (isinstance(data_summary_path, str)
                and isinstance(key_to_include_path, str)):
            return
        if not (os.path.exists(data_summary_path)
                and os.path.exists(key_to_include_path)):
            return
        if not (os.path.splitext(data_summary_path)[1] == '.csv'
                and os.path.splitext(key_to_include_path)[1] == '.csv'):
            return

        summary = pd.read_csv(data_summary_path)
        keys_to_include = pd.read_csv(key_to_include_path)

        if 'keys' not in summary.columns or 'path' not in summary.columns:
            self.logger.error(data_summary_path + ': columns "keys" and "path" are required!')
            self.error = True
            return
        if 'keys_to_include' not in keys_to_include.columns:
            self.logger.error(key_to_include_path + ': column "keys_to_include" is required!')
            self.error = True
            return

        summary_keys = list(summary['keys'])
        summary_paths = list(summary['path'])

        for key in keys_to_include['keys_to_include']:
            if key not in summary_keys:
                self.logger.error('%s: Key not found in geospatial data summary!' % (key,))
                self.error = True
                continue

            data_path = summary_paths[summary_keys.index(key)]
            if not isinstance(data_path, str):
                # Empty cells are read as NaN (a float).
                self.logger.error('%s: No file path given!' % (key,))
                self.error = True
                continue

            if os.path.splitext(data_path)[1] not in ('.nc', '.tif', '.tiff'):
                self.logger.error('%s: Wrong file format!' % (key,))
                self.error = True

            if not os.path.exists(data_path):
                self.logger.error('%s: File cannot be found!' % (key,))
                self.error = True