'''
Reads the velocity and pressure data from XML files and saves them as CSV files.

The data is read from the Data folder, separately for the following stages:
    - train
    - test

Input files per stage:
    - parameter_input.xlsx
    - velocities: velocity_field_{index}.xml
    - pressure: pressure_field_{index}.xml

Output files, written to the matrices folder of each stage:
    - velocity.csv
    - pressure.csv
    - velocity_scaled.csv
    - pressure_scaled.csv

The data is scaled with sklearn's StandardScaler; the fitted scalers are saved as:
    - scaler_velocity.pkl
    - scaler_pressure.pkl
'''
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import joblib
from miscellaneous.DataPreparation import load_xml_to_numpy, get_matrix_size
import os
import argparse
import shutil

# Degree Elevations and h-Refinements - Command Line Input
parser = argparse.ArgumentParser(
        description="Script that reads the degree elevations and h-refinements from the command line"
    )
parser.add_argument("--e", required=False, type=int, help="number of degree elevations")
parser.add_argument("--r", required=False, type=int, help="number of h-refinements")
args = parser.parse_args()
if args.e is not None:
    DEGREE_ELEVATIONS = args.e
else:
    DEGREE_ELEVATIONS = 1
if args.r is not None:
    H_REFINEMENTS = args.r
else:
    H_REFINEMENTS = 0
print('DEGREE_ELEVATIONS:', DEGREE_ELEVATIONS)
print('H_REFINEMENTS:', H_REFINEMENTS)
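# Example invocation (hypothetical script name; both flags are optional and
# fall back to the defaults above):
#   python prepare_matrices.py --e 1 --r 0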
DataFolder = f'Data/H_REFINEMENTS_{H_REFINEMENTS}'

# Sizes of the microstructures
VELOCITY_SIZE = get_matrix_size(f'{DataFolder}/train/velocities/velocity_field_0.xml')
PRESSURE_SIZE = get_matrix_size(f'{DataFolder}/train/pressure/pressure_field_0.xml')
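# Note: both sizes are taken from the first training sample and are assumed to
# hold for every XML file in the train and test stages.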

for stage in ['train', 'test']:
    df = pd.read_excel(f'{DataFolder}/{stage}/parameter_input.xlsx').transpose()
    N_FILES_STAGE = df.shape[1]

    # Initialize empty arrays
    velocity_matrix = np.empty((VELOCITY_SIZE, N_FILES_STAGE), np.float32)
    pressure_matrix = np.empty((PRESSURE_SIZE, N_FILES_STAGE), np.float32)

    for index in range(N_FILES_STAGE):
        file_identifier = f'{index}.xml'

        velocity_file = f'{DataFolder}/{stage}/velocities/velocity_field_{file_identifier}'
        pressure_file = f'{DataFolder}/{stage}/pressure/pressure_field_{file_identifier}'

        velocity_matrix[:,index] = load_xml_to_numpy(velocity_file)
        pressure_matrix[:,index] = load_xml_to_numpy(pressure_file)

    # Keep per-stage references to the assembled matrices. Not strictly
    # necessary; historically introduced because some rows had to be removed
    # at the beginning.
    # Todo: change this
    if stage == 'train':
        velocity_matrix_train = velocity_matrix
        pressure_matrix_train = pressure_matrix
    elif stage == 'test':
        velocity_matrix_test = velocity_matrix
        pressure_matrix_test = pressure_matrix

    # Save the data
    if os.path.exists(f'{DataFolder}/{stage}/matrices'):
        shutil.rmtree(f'{DataFolder}/{stage}/matrices')
    os.mkdir(f'{DataFolder}/{stage}/matrices')

    np.savetxt(f'{DataFolder}/{stage}/matrices/velocity.csv', velocity_matrix)
    np.savetxt(f'{DataFolder}/{stage}/matrices/pressure.csv', pressure_matrix)
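    # Note: np.savetxt writes space-delimited text by default, even though the
    # output files use a .csv extension.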

    # Scale the data
    scaler_velocity = StandardScaler()
    scaler_pressure = StandardScaler()
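    # Note: the scalers are re-created for every stage, so the test data is
    # scaled with statistics fitted on the test set itself; only the
    # train-stage scalers are saved to disk below.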

    # Fit the scaler
    # ! Fit with the first column, is there a better way?
    scaler_velocity = scaler_velocity.fit(velocity_matrix[:,0].reshape(-1,1))
    scaler_pressure = scaler_pressure.fit(pressure_matrix[:,0].reshape(-1,1))
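    # A possible alternative (an assumption, not the original approach): fit on
    # every column at once so the statistics cover the whole stage, e.g.
    #   scaler_velocity.fit(velocity_matrix.reshape(-1, 1))
    #   scaler_pressure.fit(pressure_matrix.reshape(-1, 1))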

    # Scale whole matrix
    if stage == 'train':
        velocity_matrix_train_scaled = np.empty_like(velocity_matrix)
        for i in range(velocity_matrix.shape[1]):
            velocity_matrix_train_scaled[:,i] = scaler_velocity.transform(velocity_matrix[:,i].reshape(-1,1)).reshape(-1)
        pressure_matrix_train_scaled = np.empty_like(pressure_matrix)
        for i in range(pressure_matrix.shape[1]):
            pressure_matrix_train_scaled[:,i] = scaler_pressure.transform(pressure_matrix[:,i].reshape(-1,1)).reshape(-1)

        # Save the scaler
        joblib.dump(scaler_velocity, f'{DataFolder}/{stage}/scaler_velocity.pkl')
        joblib.dump(scaler_pressure, f'{DataFolder}/{stage}/scaler_pressure.pkl')

        # Save the scaled data
        np.savetxt(f'{DataFolder}/{stage}/matrices/velocity_scaled.csv', velocity_matrix_train_scaled)
        np.savetxt(f'{DataFolder}/{stage}/matrices/pressure_scaled.csv', pressure_matrix_train_scaled)
    elif stage == 'test':
        velocity_matrix_test_scaled = np.empty_like(velocity_matrix)
        for i in range(velocity_matrix.shape[1]):
            velocity_matrix_test_scaled[:,i] = scaler_velocity.transform(velocity_matrix[:,i].reshape(-1,1)).reshape(-1)
        pressure_matrix_test_scaled = np.empty_like(pressure_matrix)
        for i in range(pressure_matrix.shape[1]):
            pressure_matrix_test_scaled[:,i] = scaler_pressure.transform(pressure_matrix[:,i].reshape(-1,1)).reshape(-1)

        # Save the scaled data
        np.savetxt(f'{DataFolder}/{stage}/matrices/velocity_scaled.csv', velocity_matrix_test_scaled)
        np.savetxt(f'{DataFolder}/{stage}/matrices/pressure_scaled.csv', pressure_matrix_test_scaled)
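# Downstream usage sketch (an assumption, not part of this script): the saved
# artifacts can be reloaded with, for example,
#   scaler_velocity = joblib.load(f'{DataFolder}/train/scaler_velocity.pkl')
#   velocity_scaled = np.loadtxt(f'{DataFolder}/train/matrices/velocity_scaled.csv')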