# shire.py
# (removed non-Python web-page residue from a copied GitLab view: navigation
#  text and the "shire.py 3.76 KiB" file header are not part of the script)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import settings
import pandas as pd
import os

from create_training_data import create_training_data
from create_prediction_data import create_prediction_data
from RandomForest import RandomForest
from check_user_input import check_general_settings
from utilities.initialise_log import save_log

"""
    This script controls the hazard mapping framework SHIRE.
    Adapt settings.py before running this script. Ensure a data summary csv file
    and a csv file containing the keys to include.
    For more information please refer to the user manual.
"""

print('SHIRE - Landslide hazard mapping framework')
print('If you have not prepared settings.py and \
      the necessary csv files, stop the script.')

# Validate the user-supplied settings before doing any work; abort on errors.
s = check_general_settings()
if s.error:
    print('Please check settings.py again, there are errors listed in the log.')
else:
    # Place the run log next to the dataset the run is going to touch.
    if settings.training_dataset or settings.map_generation:
        save_path = os.path.dirname(settings.path_train) + '/shire_run.log'
    elif settings.prediction_dataset:
        save_path = os.path.dirname(settings.path_pred) + '/shire_run.log'
    else:
        # Previously save_path was left unbound here and the script died with
        # an opaque NameError; fail with an actionable message instead.
        raise RuntimeError(
            'No processing step enabled in settings.py: set training_dataset, '
            'prediction_dataset or map_generation to True.')

    # Start every run with a fresh log file.
    if os.path.exists(save_path):
        os.remove(save_path)
    logger = save_log(save_path)

    settings.export_variables(logger)

    if settings.training_dataset:
        print('Training dataset will be generated')
        logger.info('Training dataset generation started')
        # Map the preprocessing keyword onto the cluster/interpolation flags.
        # BUG FIX: the original guard was `is None`, so the string comparisons
        # below could never match and this mapping was dead code; it must run
        # when a preprocessing mode IS given.
        if settings.preprocessing is not None:
            if settings.preprocessing == 'cluster':
                cluster = True
                interpolation = True
            elif settings.preprocessing == 'interpolation':
                cluster = False
                interpolation = True
            elif settings.preprocessing == 'no_interpolation':
                cluster = False
                interpolation = False
        # NOTE(review): the cluster/interpolation locals computed above are
        # never used — the call below reads settings.cluster and
        # settings.interpolation. Confirm whether the locals were meant to be
        # passed instead; left unchanged to preserve current behaviour.

        s = create_training_data(
                     from_scratch=settings.train_from_scratch,
                     delete=settings.train_delete,
                     data_to_handle=list(pd.read_csv(settings.key_to_include_path)['keys_to_include']),
                     cluster=settings.cluster,
                     interpolation=settings.interpolation,
                     preprocessing=settings.preprocessing,
                     log=logger)

        print('Training dataset successfully created')
        logger = s.logger
        logger.info('Training dataset successfully created')

    if settings.prediction_dataset:
        print('Prediction dataset will be generated')
        logger.info('Prediction dataset generation started')

        s = create_prediction_data(
            from_scratch=settings.pred_from_scratch,
            delete=settings.pred_delete,
            log=logger)

        print('Prediction dataset successfully created')
        logger = s.logger
        logger.info('Prediction dataset successfully created')

    if settings.map_generation:
        print('Map will be generated')
        logger.info('Map generation started')

        if settings.parallel:
            print('Prediction will run in parallel')
            logger.info('Prediction will run in parallel')
        if settings.RF_training:
            logger.info('Random Forest training is launched')
            s = RandomForest('train_test', parallel=settings.parallel, log=logger)
            logger = s.logger
        if settings.RF_prediction:
            # Typo fix: "prediction in launched" -> "is launched".
            logger.info('Random Forest prediction is launched')
            s = RandomForest('prediction', parallel=settings.parallel, log=logger)
            logger = s.logger

        print('Map successfully created')
        logger.info('Map successfully created')

    # Close and detach all handlers so the log file is flushed and released.
    # Iterate over a COPY: removeHandler mutates logger.handlers, and mutating
    # the list while iterating it skips every other handler.
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)