
import argparse
from pprint import pprint

from typing import List, Tuple, Union

from scipy.stats import spearmanr

import pandas as pd

import numpy as np

class CorrelationExecuter:
    """Correlate columns averaged over several CSV datasets with a reference column.

    The reference CSV and every dataset CSV are read with their first column as
    the index and inner-joined on that index into ``self.dataset``.
    """

    # Supported correlation algorithms: name -> (function, output column names).
    ALGORITHMS = {
        "spearman": (spearmanr, ["spearman_corr", "spearman_p"]),
    }

    def __init__(self, datasets_paths: List[str],
                 reference_path: str,
                 reference_column: str,
                 correlation_algorithm: str = "spearman",
                 lower_indices: bool = True,
                 separator: str = ","):
        """Select the correlation algorithm and load all CSV files.

        Args:
            datasets_paths: CSV files holding the columns to compare; the first
                column of each file is used as the join index.
            reference_path: CSV file holding ``reference_column``.
            reference_column: Column of the reference file to correlate against.
            correlation_algorithm: Key of ``ALGORITHMS``.
            lower_indices: Lower-case string index labels before joining.
            separator: Field separator used when reading every CSV file.

        Raises:
            KeyError: if ``correlation_algorithm`` is not supported.
            ValueError: if ``reference_column`` is not in the reference file.
        """
        # Resolve the algorithm before reading any file so a bad name fails fast.
        algorithm, alg_column = self.ALGORITHMS[correlation_algorithm]
        self.algorithm = algorithm
        self.alg_column = list(alg_column)  # per-instance copy of the shared table entry

        self.ref_column = reference_column
        self.lower_indices = lower_indices
        self.separator = separator
        self.load_datasets(reference_path, datasets_paths)
        self.number_datasets = len(datasets_paths)

        super().__init__()

    def compare_columns(self, columns: List[str]) -> pd.DataFrame:
        """Correlate the cross-dataset average of ``columns`` with the reference column.

        Each requested column is first averaged element-wise over all loaded
        datasets. With a single column that average is correlated directly
        (row ``'res'``); with several columns the element-wise ``max``, ``mean``
        and ``min`` across the averaged columns are each correlated (rows
        ``'max'``, ``'mean'``, ``'min'``).

        Args:
            columns: Column names present in every dataset.

        Returns:
            DataFrame indexed by the row labels above, with the columns listed
            in ``self.alg_column`` (statistic and p-value).

        Raises:
            ValueError: if ``columns`` is empty, no dataset was loaded, or (via
                ``nan_policy='raise'``) the compared values contain NaN.
            KeyError: if a column is missing from one of the datasets.
        """
        if len(columns) == 0:
            raise ValueError("At least one column is required for comparison.")
        if self.number_datasets == 0:
            raise ValueError("No datasets were loaded to compare.")

        # Row i holds column i averaged over every dataset: shape (len(columns), n_rows).
        column_totals = np.empty((len(columns), len(self.dataset)))
        for i, column in enumerate(columns):
            stacked = np.vstack([
                self.dataset[self._joined_name(position, column)].to_numpy(dtype=float)
                for position in range(self.number_datasets)
            ])
            column_totals[i] = stacked.mean(axis=0)

        reference = self.dataset[self.ref_column].to_numpy()
        if len(columns) == 1:
            summaries = {"res": column_totals[0]}
        else:
            summaries = {
                "max": column_totals.max(axis=0),
                "mean": column_totals.mean(axis=0),
                "min": column_totals.min(axis=0),
            }

        rows = {
            label: list(self.algorithm(a=values, b=reference, nan_policy="raise"))
            for label, values in summaries.items()
        }
        return pd.DataFrame.from_dict(rows, orient="index", columns=self.alg_column)

    def load_datasets(self, reference_path: str, datasets_paths: List[str]):
        """Read the reference and dataset CSVs and inner-join them into ``self.dataset``.

        Columns of dataset number ``i`` that clash with a column already in the
        joined frame are renamed with the suffix ``str(i)`` (pandas ``join``
        ``rsuffix``). The resulting name of each dataset column is recorded in
        ``self._column_names[i]`` so that ``compare_columns`` reads the right data
        even when a dataset shares a column name with the reference file.

        Raises:
            ValueError: if ``self.ref_column`` is not a column of the reference file.
        """
        self.dataset = self._read_csv(reference_path)
        if self.ref_column not in self.dataset.columns:
            raise ValueError(f'{self.ref_column} is not in the reference dataset.')

        self._column_names = []
        for position, dataset_path in enumerate(datasets_paths):
            dataset = self._read_csv(dataset_path)
            suffix = str(position)
            existing = set(self.dataset.columns)
            # Mirror pandas' renaming: only right-hand columns that overlap get the suffix.
            self._column_names.append({
                column: f"{column}{suffix}" if column in existing else column
                for column in dataset.columns
            })
            self.dataset = self.dataset.join(dataset, how="inner", rsuffix=suffix)

    def _read_csv(self, path: str) -> pd.DataFrame:
        """Read one CSV with its first column as index, lower-casing str labels if requested."""
        frame = pd.read_csv(path, index_col=0, sep=getattr(self, "separator", ","))
        if self.lower_indices:
            # Only str labels are lower-cased; `.str.lower()` would raise on numeric labels.
            frame.index = frame.index.map(
                lambda label: label.lower() if isinstance(label, str) else label)
        return frame

    def _joined_name(self, position: int, column: str) -> str:
        """Return the name that ``column`` of dataset ``position`` has in ``self.dataset``."""
        names = getattr(self, "_column_names", None)
        if names is None:
            # No join bookkeeping available: fall back to the positional suffix convention.
            return column if position == 0 else f"{column}{position}"
        try:
            return names[position][column]
        except KeyError:
            raise KeyError(f"Column '{column}' is not in dataset number {position}.") from None



def define_arguments() -> argparse.ArgumentParser:
    """Build the command-line parser for the correlation script."""
    parser = argparse.ArgumentParser(description="Generate correlation among columns")

    parser.add_argument("datasets_paths", nargs="+",
                        help="Path(s) to CSV dataset files. First column is used as index for joins.")
    parser.add_argument("-s", "--separator", default=",",
                        help="Separator character used by every input CSV file (datasets and target).")
    parser.add_argument("-c", "--columns", nargs="+", required=True,
                        help="One or more column names present in every dataset. Each column is "
                             "averaged across datasets before being correlated with the target.")
    parser.add_argument("-t", "--target-path", required=True,
                        help="Path to CSV file containing the target column. "
                             "The first column is used as index for joining datasets.")
    parser.add_argument("-tc", "--target-column", required=True,
                        help="Name of the target column inside the target CSV file.")
    parser.add_argument("-corr", "--correlation", default="spearman",
                        choices=sorted(CorrelationExecuter.ALGORITHMS),
                        help="Which algorithm to use for correlation")
    parser.add_argument("-o", "--output", required=True,
                        help="Path of the CSV file the correlation results are written to.")

    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Also print the correlation results to standard output.")

    return parser

def main(argv=None):
    """Run the CLI; ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour)."""
    args = define_arguments().parse_args(argv)

    corr_exec = CorrelationExecuter(datasets_paths=args.datasets_paths,
                                    reference_path=args.target_path,
                                    reference_column=args.target_column,
                                    correlation_algorithm=args.correlation,
                                    separator=args.separator)

    output_df = corr_exec.compare_columns(args.columns)
    output_df.to_csv(args.output)
    if args.verbose:
        print(output_df)

if __name__ == "__main__":
    main()
