Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
shire.py 4.20 KiB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pickle
import os
import tkinter as tk

from create_training_data_gui import *
from create_prediction_data_gui import *
from RandomForest_gui import *

from check_user_input import check_general_settings
from utilities.initialise_log import save_log
from utilities.gui import *

"""
    This script controls the hazard mapping framework SHIRE.
    Before running it, make sure that a data summary CSV file
    and a CSV file listing the keys to include have been prepared.
    For more information, please refer to the user manual.
"""

# Delete temporary pickle files left over from a previous run. The main
# pipeline below reads tmp_settings.pkl only if it exists, so a stale copy
# would otherwise be picked up.
for _tmp_file in ('tmp_settings.pkl', 'tmp_train.pkl', 'tmp_pred.pkl', 'tmp_map.pkl'):
    if os.path.isfile(_tmp_file):
        os.remove(_tmp_file)

# Show the general-settings window; mainloop() keeps it running until the
# user closes it.
master = tk.Tk()
general_settings(master)
master.mainloop()

# Validate the entered settings; `s.error` is checked below to decide
# whether the pipeline may run.
s = check_general_settings()

# Every launch starts with a fresh run log.
if os.path.exists('shire_run.log'):
    os.remove('shire_run.log')
logger = save_log('shire_run.log')
for _msg in ('SHIRE has successfully been launched',
             'User input required',
             'General settings defined'):
    logger.info(_msg)

def _close_log_handlers(log):
    """Close and detach every handler attached to *log*.

    Iterates over a copy of ``log.handlers`` because ``removeHandler``
    mutates that list; removing while iterating the live list skips every
    other handler and leaves it open.
    """
    for handler in list(log.handlers):
        handler.close()
        log.removeHandler(handler)


def _run_random_forest(master, mode, properties_map, log, retrain=False):
    """Run one RandomForest step and return the resulting object.

    Args:
        master: Tk root handed to RandomForest.
        mode: 'train_test' or 'prediction'.
        properties_map: map-generation settings loaded from tmp_map.pkl.
        log: logger passed through as ``log=``.
        retrain: when True, ``retrain=True`` is passed to RandomForest.

    ``parallel=True`` is passed only for the 'prediction' mode when the
    map settings request it, and ``retrain=True`` only when *retrain* is
    set, so each call receives exactly the keyword arguments the settings
    call for.
    """
    kwargs = {'log': log}
    if mode == 'prediction' and properties_map['parallel'] == 1:
        kwargs['parallel'] = True
    if retrain:
        kwargs['retrain'] = True
    return RandomForest(master, mode, **kwargs)


if s.error:
    logger.info('There is an error in the user input. For more infos check the check_user_input.log')
    _close_log_handlers(logger)
elif not os.path.isfile('tmp_settings.pkl'):
    # Without the settings file nothing was selected. Previously this case
    # fell through to an undefined `properties_settings` (NameError).
    logger.info('General settings file tmp_settings.pkl not found, nothing to run')
    _close_log_handlers(logger)
else:
    # NOTE(review): pickle.load is only acceptable here because the file is
    # presumably written by the general-settings GUI above; never load
    # pickles from untrusted sources.
    with open('tmp_settings.pkl', 'rb') as handle:
        properties_settings = pickle.load(handle)

    # --- Training dataset -------------------------------------------------
    if properties_settings['train'] == 1:
        master = tk.Tk()
        logger.info('Training dataset generation started')
        s = create_training_data(master=master, log=logger)
        os.remove('tmp_train.pkl')
        logger = s.logger
        if properties_settings['pred'] != 1 and properties_settings['map'] != 1:
            # Training is the last selected step: release the log file.
            _close_log_handlers(logger)
        master.destroy()

    # --- Prediction dataset -----------------------------------------------
    if properties_settings['pred'] == 1:
        master = tk.Tk()
        logger.info('Prediction dataset generation started')
        s = create_prediction_data(master=master, log=logger)
        os.remove('tmp_pred.pkl')
        logger = s.logger
        # Only map generation can follow; the former extra `pred != 1` test
        # was always False inside this branch, so handlers were never closed.
        if properties_settings['map'] != 1:
            _close_log_handlers(logger)
        master.destroy()

    # --- Map generation ---------------------------------------------------
    if properties_settings['map'] == 1:
        master = tk.Tk()
        logger.info('Map generation started')
        with open('tmp_map.pkl', 'rb') as handle:
            properties_map = pickle.load(handle)

        # Training always runs before prediction when both are selected.
        modes = []
        if properties_map['training'] == 1:
            modes.append('train_test')
        if properties_map['prediction'] == 1:
            modes.append('prediction')

        # Reset so an object left over from an earlier step is never
        # mistaken for a RandomForest result below.
        s = None
        for mode in modes:
            s = _run_random_forest(master, mode, properties_map, logger)

        if s is None:
            logger.info('Neither training nor prediction selected for map generation')
        elif s.retrain:
            print('Retrain necessary')
            logger.info('Retrain necessary')
            for mode in ('train_test', 'prediction'):
                s = _run_random_forest(master, mode, properties_map, logger, retrain=True)

        os.remove('tmp_map.pkl')
        if s is not None:
            logger = s.logger
        # Map generation is always the final step.
        _close_log_handlers(logger)