import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import joblib
from miscellaneous.DataPreparation import load_xml_to_numpy, get_matrix_size
import os
import argparse
import shutil

# Command-line configuration.
#   --e : integer stored in DEGREE_ELEVATIONS (defaults to 1 when omitted)
#   --r : integer stored in H_REFINEMENTS (defaults to 0 when omitted);
#         it also selects the data folder Data/H_REFINEMENTS_<r>
parser = argparse.ArgumentParser(
        description="Script that reads the Elevation Degrees from CMD"
    )
parser.add_argument("--e", required=False, type=int)
parser.add_argument("--r", required=False, type=int)
args = parser.parse_args()

# Bug fix: DEGREE_ELEVATIONS used to be read from ``args.r``, so ``--e`` was
# silently ignored (and became None when only ``--e`` was passed).
DEGREE_ELEVATIONS = args.e if args.e is not None else 1
H_REFINEMENTS = args.r if args.r is not None else 0

print('DEGREE_ELEVATIONS:', DEGREE_ELEVATIONS)
print('H_REFINEMENTS:', H_REFINEMENTS)

# Root folder of the data set; only the refinement level is part of the path.
DataFolder = f'Data/H_REFINEMENTS_{H_REFINEMENTS}'

def _scale_matrix(scaler, matrix):
    """Apply a fitted single-feature scaler to every entry of ``matrix``.

    The scaler is fitted on a single column vector (one feature), so its
    transform is the same affine map for every entry. Transforming the
    flattened matrix in one call is therefore equivalent to transforming it
    column by column. The result keeps the shape and dtype of ``matrix``.
    """
    scaled = scaler.transform(matrix.reshape(-1, 1)).reshape(matrix.shape)
    return scaled.astype(matrix.dtype, copy=False)


# Length of one velocity / pressure sample, taken from the first training
# file and used as the row count of the assembled matrices.
VELOCITY_SIZE = get_matrix_size(f'{DataFolder}/train/velocities/velocity_field_0.xml')
PRESSURE_SIZE = get_matrix_size(f'{DataFolder}/train/pressure/pressure_field_0.xml')

# 'train' must be processed before 'test': the scalers are fitted on the
# training data only and then reused unchanged for the test data.
for stage in ['train', 'test']:
    # After the transpose, the column count is used as the number of
    # velocity/pressure files available for this stage.
    df = pd.read_excel(f'{DataFolder}/{stage}/parameter_input.xlsx').transpose()
    N_FILES_STAGE = df.shape[1]

    # One column per sample file
    velocity_matrix = np.empty((VELOCITY_SIZE, N_FILES_STAGE), np.float32)
    pressure_matrix = np.empty((PRESSURE_SIZE, N_FILES_STAGE), np.float32)

    for index in range(N_FILES_STAGE):
        file_identifier = f'{index}.xml'

        velocity_file = f'{DataFolder}/{stage}/velocities/velocity_field_{file_identifier}'
        pressure_file = f'{DataFolder}/{stage}/pressure/pressure_field_{file_identifier}'

        velocity_matrix[:, index] = load_xml_to_numpy(velocity_file)
        pressure_matrix[:, index] = load_xml_to_numpy(pressure_file)

    # Start from a clean output folder for the matrices
    matrices_dir = f'{DataFolder}/{stage}/matrices'
    if os.path.exists(matrices_dir):
        shutil.rmtree(matrices_dir)
    os.mkdir(matrices_dir)

    # Save the raw data
    np.savetxt(f'{matrices_dir}/velocity.csv', velocity_matrix)
    np.savetxt(f'{matrices_dir}/pressure.csv', pressure_matrix)

    if stage == 'train':
        # Fit the scalers on training data only. Previously a fresh scaler was
        # also fitted on the test data, so the scaled test matrices did not
        # match the scalers saved below.
        # TODO: the fit still uses only the first training column; fitting on
        # the whole training matrix may be preferable.
        scaler_velocity = StandardScaler().fit(velocity_matrix[:, 0].reshape(-1, 1))
        scaler_pressure = StandardScaler().fit(pressure_matrix[:, 0].reshape(-1, 1))

        # Save the scalers
        joblib.dump(scaler_velocity, f'{DataFolder}/{stage}/scaler_velocity.pkl')
        joblib.dump(scaler_pressure, f'{DataFolder}/{stage}/scaler_pressure.pkl')

    # Scale the whole matrix with the training scalers
    velocity_matrix_scaled = _scale_matrix(scaler_velocity, velocity_matrix)
    pressure_matrix_scaled = _scale_matrix(scaler_pressure, pressure_matrix)

    # Save the scaled data
    np.savetxt(f'{matrices_dir}/velocity_scaled.csv', velocity_matrix_scaled)
    np.savetxt(f'{matrices_dir}/pressure_scaled.csv', pressure_matrix_scaled)

    # Keep per-stage module-level names for the raw and scaled matrices
    # (``pressuer_matrix_train`` was a typo for ``pressure_matrix_train``).
    if stage == 'train':
        velocity_matrix_train = velocity_matrix
        pressure_matrix_train = pressure_matrix
        velocity_matrix_train_scaled = velocity_matrix_scaled
        pressure_matrix_train_scaled = pressure_matrix_scaled
    else:
        velocity_matrix_test = velocity_matrix
        pressure_matrix_test = pressure_matrix
        velocity_matrix_test_scaled = velocity_matrix_scaled
        pressure_matrix_test_scaled = pressure_matrix_scaled