import pandas as pd
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split
import numpy as np
import re

class TabularDataImporter:
    """Load a CSV file into a DataFrame and clean it for numeric modelling.

    The cleaning steps turn text cells such as ``"1,5 l"`` into floats, map
    missing values (NaN and the marker ``"n.b."``) to ``0``, one-hot encode
    the remaining text columns and convert boolean columns to integers.

    Attributes:
        filename: Path of the CSV file that was read.
        dataFrame: The loaded table. The ``preprocess*`` and ``encode*``
            methods modify it in place; ``prepare`` and
            ``split_train_test_set`` only drop rows on a local copy.
    """

    # Marker used in the data for a missing value.
    _MISSING_MARKER = "n.b."

    # The row holding the largest value of this column is removed before the
    # data is used (presumably an outlier -- confirm against the data set).
    _MAX_DROP_COLUMN = "LeistungPS"

    # Rows with a 0 in any of these columns are removed. After cleaning, 0
    # stands for "missing". The two entry points historically differ in
    # whether 'LeistungPS' is checked; that difference is preserved.
    _SPLIT_NONZERO_COLUMNS = ("Beschleunigung", "Gesamtgewicht")
    _PREPARE_NONZERO_COLUMNS = ("Beschleunigung", "Gesamtgewicht", "LeistungPS")

    # Optional leading whitespace followed by everything from the first "("
    # to the last ")" in the cell.
    _PARENTHESISED = re.compile(r"\s*\([\w\W]*\)")
    _WHITESPACE = re.compile(r"[\s]")

    def __init__(self, csvfile: str):
        """Read ``csvfile`` with ``pandas.read_csv`` and keep the result."""
        self.filename: str = csvfile
        self.dataFrame: pd.DataFrame = pd.read_csv(csvfile)

    # ------------------------------------------------------------------
    # cell-level converters
    # ------------------------------------------------------------------

    @staticmethod
    def _to_number_or_keep(value):
        """Convert one cell for ``preprocess``.

        * ``"n.b."`` and NaN become ``0``.
        * A string whose first whitespace-separated token is a number (after
          replacing ``,`` with ``.``) becomes that ``float``.
        * Anything else (other text, empty strings, ints, non-NaN floats,
          non-numeric objects) is returned unchanged.
        """
        if isinstance(value, str):
            if value == TabularDataImporter._MISSING_MARKER:
                return 0
            tokens = value.replace(",", ".").split()
            if not tokens:
                # Empty / whitespace-only text has no leading token to parse.
                return value
            try:
                return float(tokens[0])
            except ValueError:
                return value

        try:
            missing = np.isnan(value)
        except TypeError:
            # np.isnan rejects non-numeric objects; leave those untouched.
            return value
        return 0 if missing else value

    @staticmethod
    def _clean_cell(value):
        """Convert one cell for ``preprocess_all``.

        * NaN floats become ``0``.
        * For strings, a parenthesised part is stripped first. Then
          ``"n.b."`` becomes ``0``; text containing a digit is parsed as the
          float of its first token (``,`` read as ``.``); all other text has
          its whitespace replaced by ``_`` so it can serve as a category.
        * Text that contains a digit but whose first token is not a number
          (e.g. ``"Audi A4"``) is treated as a category as well.
        * Every other value (ints, bools, non-NaN floats, ...) is returned
          unchanged.
        """
        if isinstance(value, str):
            text = TabularDataImporter._PARENTHESISED.sub("", value)
            if text == TabularDataImporter._MISSING_MARKER:
                return 0
            if any(char.isdigit() for char in text):
                try:
                    return float(text.replace(",", ".").split()[0])
                except ValueError:
                    # Digits are present but the leading token is not
                    # numeric: fall through to the categorical branch.
                    pass
            return TabularDataImporter._WHITESPACE.sub("_", text)

        if isinstance(value, float) and np.isnan(value):
            return 0
        return value

    # ------------------------------------------------------------------
    # column-level preprocessing
    # ------------------------------------------------------------------

    def preprocess(self, predictor: str) -> None:
        """Clean the single column ``predictor`` of ``self.dataFrame`` in place.

        Missing values (NaN, ``"n.b."``) are set to zero and numeric text is
        converted to ``float``; see ``_to_number_or_keep`` for the rules.
        """
        self.dataFrame[predictor] = self.dataFrame[predictor].apply(
            self._to_number_or_keep
        )

    def preprocess_all(self) -> None:
        """Clean every column of ``self.dataFrame`` in place.

        See ``_clean_cell`` for the per-cell rules.
        """
        for column in self.dataFrame:
            self.dataFrame[column] = self.dataFrame[column].apply(self._clean_cell)

    # one-hot encoding of categorical features

    def encodeCategorical(self) -> None:
        """Replace every ``object``-dtype column by its one-hot dummy columns.

        Dummy columns are named ``<column>_<value>``.
        """
        categorical: List[str] = list(
            self.dataFrame.select_dtypes(include=["object"]).columns
        )
        if not categorical:
            return
        # A single call encodes all columns; dummies are appended in the
        # order of ``categorical``, as the former per-column loop did.
        self.dataFrame = pd.get_dummies(
            self.dataFrame, columns=categorical, prefix_sep="_"
        )

    def encodeBool(self) -> None:
        """Cast every ``bool``-dtype column of ``self.dataFrame`` to int."""
        bool_columns: List[str] = list(
            self.dataFrame.select_dtypes(include=["bool"]).columns
        )
        for column in bool_columns:
            self.dataFrame[column] = self.dataFrame[column].astype("int")

    # ------------------------------------------------------------------
    # row filtering
    # ------------------------------------------------------------------

    @staticmethod
    def _drop_rows(frame: pd.DataFrame, max_column: str, zero_columns) -> pd.DataFrame:
        """Return ``frame`` without the max row of ``max_column`` and without
        rows that hold ``0`` in any of ``zero_columns``.

        ``frame`` itself is not modified. An empty frame is returned as is,
        because ``idxmax`` is undefined for it.
        """
        if not frame.empty:
            frame = frame.drop(frame[max_column].idxmax())
        for column in zero_columns:
            frame = frame.drop(frame.loc[frame[column] == 0].index)
        return frame

    def split_train_test_set(
        self,
        predictors: List[str],
        response: List[str],
        *,
        random_state=None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split the data 80/20 into training and test sets.

        Before splitting, the row with the largest ``LeistungPS`` and all rows
        with a zero ``Beschleunigung`` or ``Gesamtgewicht`` are removed. The
        removal happens on a local frame: ``self.dataFrame`` is left
        untouched, so repeated calls do not keep discarding further rows.

        Args:
            predictors: Columns to use as features. If empty, every column
                except those in ``response`` is used.
            response: Column(s) to predict.
            random_state: Forwarded to ``train_test_split``; pass a value to
                make the split reproducible. ``None`` keeps it random.

        Returns:
            ``(X_train, X_test, y_train, y_test)``.
        """
        frame = self._drop_rows(
            self.dataFrame, self._MAX_DROP_COLUMN, self._SPLIT_NONZERO_COLUMNS
        )

        if not predictors:
            X = frame.drop(response, axis=1)
        else:
            X = frame[predictors]
        y = frame[response]

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, train_size=0.8, test_size=0.2, random_state=random_state
        )
        return (X_train, X_test, y_train, y_test)

    def prepare(self):
        """Run the full cleaning pipeline and return the filtered table.

        ``preprocess_all``, ``encodeCategorical`` and ``encodeBool`` are
        applied to ``self.dataFrame`` in place. The returned frame
        additionally lacks the row with the largest ``LeistungPS`` and every
        row with a zero in ``Beschleunigung``, ``Gesamtgewicht`` or
        ``LeistungPS``; those rows remain in ``self.dataFrame``.
        """
        self.preprocess_all()
        self.encodeCategorical()
        self.encodeBool()

        return self._drop_rows(
            self.dataFrame, self._MAX_DROP_COLUMN, self._PREPARE_NONZERO_COLUMNS
        )

    # ------------------------------------------------------------------
    # summary statistics (each one re-runs ``prepare``)
    # ------------------------------------------------------------------

    def _prepared_column(self, feature):
        """Return column ``feature`` of the table produced by ``prepare``."""
        return self.prepare()[feature]

    def get_mean(self, feature):
        """Return the mean of ``feature`` over the prepared table."""
        return self._prepared_column(feature).mean()

    def get_std(self, feature):
        """Return the standard deviation of ``feature`` over the prepared table."""
        return self._prepared_column(feature).std()

    def get_variance(self, feature):
        """Return the variance of ``feature`` over the prepared table."""
        return self._prepared_column(feature).var()

    def get_median(self, feature):
        """Return the median of ``feature`` over the prepared table."""
        return self._prepared_column(feature).median()


#tabd = TabularDataImporter("data.csv")
# tabd.preprocess_all()
# tabd.encodeCategorical()
