import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tabularDataImporter import TabularDataImporter
import tensorflow as tf
from typing import List, Dict, Tuple

from tensorflow import keras
from tensorflow.keras import layers
from matplotlib.colors import ListedColormap


from tabularDataImporter import TabularDataImporter
from dnn_regression import DNNRegressor
from linear_regression import LinearRegressor
from nn_linear_regression import NN1Regressor
from polynomial_regression import PolynomialRegressor
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D


class Regressiontest:
    """Train several regressor classes on one CSV file and compare them visually.

    Each entry of ``RegressionModels`` is a regressor class instantiated as
    ``cls(csvfile, predictors, response)``. Judging from the calls made in this
    class, an instance must provide ``train(epochs)``, ``evaluate()``,
    ``response_predict(DataFrame)``, ``toString()`` and a ``data.dataFrame``
    attribute holding its training data.
    """

    # Line styles / colours cycled through when several models share one plot,
    # so any number of models can be drawn without running off the end.
    _LINE_STYLES = ("solid", "dotted", "dashed", "dashdot")
    _VAR_COLORS = ("lightgreen", "red", "yellow")

    def __init__(self, RegressionModels: List[object], csvfile: str, predictors: List[str], response: List[str], epochs: int = 50):
        """Load *csvfile*, then build and train one model per regressor class.

        Args:
            RegressionModels: regressor classes, each called as
                ``cls(csvfile, predictors, response)``.
            csvfile: CSV file handed to ``TabularDataImporter`` and to every
                regressor.
            predictors: column names used as model inputs.
            response: column name(s) the models predict.
            epochs: value passed to each model's ``train``; defaults to 50,
                the previously hard-coded value.
        """
        self.regressors = RegressionModels
        self.csvfile = csvfile
        self.dataFrame = TabularDataImporter(self.csvfile).prepare()
        self.predictors = predictors
        self.response = response
        self.models = [regressor(csvfile, predictors, response) for regressor in self.regressors]
        for model in self.models:
            model.train(epochs)
        # Per-feature-pair models, (re)built by plot_triple_3_features.
        self.variable_models = []

    def set_predictor_and_response(self, pred, resp):
        """Replace the predictor and response column lists used for plotting."""
        self.predictors = pred
        self.response = resp

    def train_var_models(self, epochs=1):
        """Call ``train(epochs)`` on every model in ``self.variable_models``."""
        for model in self.variable_models:
            model.train(epochs)

    @classmethod
    def _style(cls, index):
        """Return the (linestyle, colour) pair for the *index*-th model, cycling."""
        return (cls._LINE_STYLES[index % len(cls._LINE_STYLES)],
                cls._VAR_COLORS[index % len(cls._VAR_COLORS)])

    def _line_data(self, regressor) -> pd.DataFrame:
        """Return the x-values used to draw *regressor*'s fitted curve.

        The values run from the minimum of the first predictor column in the
        regressor's training data up to (but excluding) its maximum, in steps
        of 1, wrapped in a single-column DataFrame (column label 0).
        """
        columns = regressor.data.dataFrame[self.predictors]
        # .iloc: positional access. Series[0] on a name-indexed Series relies
        # on a deprecated pandas fallback.
        lowest = columns.min().iloc[0]
        highest = columns.max().iloc[0]
        return pd.DataFrame(np.arange(lowest, highest, 1))

    def create_graph(self) -> None:
        """Scatter the training data and overlay every model's fitted curve.

        Only a single predictor can be drawn as a 2-D curve. For any other
        number of predictors (or when there are no models) the method returns
        without plotting. As a side effect, each model's name and the first
        two values of its ``evaluate()`` result (printed as mse / r2) are
        printed.
        """
        # The x axis is the one predictor. With two predictors the scatter
        # below would receive twice as many x values as y values.
        if len(self.predictors) != 1 or not self.models:
            return

        # sns.set() resets the style, so the style is passed in the same call;
        # a separate set_style('whitegrid') beforehand would be overwritten.
        sns.set(style="whitegrid", font_scale=2)

        _, ax = plt.subplots()
        data = self.models[0].data.dataFrame
        ax.scatter(data[self.predictors], data[self.response])

        for i, regressor in enumerate(self.models):
            linestyle, _ = self._style(i)
            linedata = self._line_data(regressor)
            y_pred = regressor.response_predict(linedata)
            ax.plot(linedata, y_pred, linewidth=3, linestyle=linestyle, color="black",
                    label=regressor.toString().replace("Regressor", ""))

            print("####")
            print(regressor.toString())
            scores = regressor.evaluate()
            print(f"mse: {scores[0]}; r2: {scores[1]}")

        ax.legend()
        plt.show()

    def add_var_model_graphs(self, ax):
        """Draw the fitted curve of every model in ``self.variable_models`` on *ax*.

        Line style and colour cycle per model. Each line is labelled with the
        model's name without the "Regressor" part.
        """
        for i, regressor in enumerate(self.variable_models):
            linestyle, colour = self._style(i)
            linedata = self._line_data(regressor)
            y_pred = regressor.response_predict(linedata)
            ax.plot(linedata, y_pred, linewidth=3, linestyle=linestyle, color=colour,
                    label=regressor.toString().replace("Regressor", ""))

    def _plot_feature_pair(self, ax, predictor, response, xs, ys, colour_by, cmap, reference, xlabel, ylabel):
        """Train fresh models for one predictor/response pair and draw one panel.

        The models (one per class in ``self.regressors``, 50 epochs) are
        stored in ``self.variable_models`` and drawn on *ax*. The method also
        scatters (xs, ys) coloured by *colour_by*, marks the (x, y) pair
        *reference* with a green dot, and labels both axes.
        """
        self.set_predictor_and_response([predictor], [response])
        self.variable_models = [regressor(self.csvfile, [predictor], [response]) for regressor in self.regressors]
        self.train_var_models(epochs=50)
        self.add_var_model_graphs(ax)
        ax.scatter(xs, ys, s=80, c=colour_by, marker='+', cmap=cmap, alpha=1)
        ax.scatter(reference[0], reference[1], marker='o', s=120, color="g")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    def plot_triple_3_features(self):
        """Plot the three pairwise relations between power, weight and acceleration.

        Three side-by-side panels are drawn: (a) power -> weight,
        (b) weight -> acceleration and (c) power -> acceleration. Each panel
        trains its own models on that single feature pair and overlays them
        on a scatter of the data, coloured by the third feature. A fixed
        reference point (136 hp, 875 kg, 6.5 s) is marked in green.

        ``self.predictors``/``self.response`` are switched per panel and
        restored afterwards, so they keep matching ``self.models``.
        ``self.variable_models`` keeps the last panel's models.
        """
        tabular = self.dataFrame
        # Drop the single highest-power row (presumably an outlier — confirm)
        # and rows with 0 acceleration or 0 weight (presumably missing data).
        tabular = tabular.drop(tabular['LeistungPS'].idxmax())
        tabular = tabular.drop(tabular.loc[tabular['Beschleunigung'] == 0].index)
        tabular = tabular.drop(tabular.loc[tabular['Gesamtgewicht'] == 0].index)

        power = tabular["LeistungPS"]
        weight = tabular["Gesamtgewicht"]
        accel = tabular["Beschleunigung"]

        sns.set(font_scale=2)
        fig, axes = plt.subplots(1, 3, figsize=(18, 10))
        cmap = ListedColormap(sns.diverging_palette(250, 30, l=65, center="dark").as_hex())

        saved = (self.predictors, self.response)
        try:
            self._plot_feature_pair(axes[0], "LeistungPS", "Gesamtgewicht", power, weight, accel,
                                    cmap, (136, 875), "power (hp)", "weight (kg)")
            axes[0].set_title("(a)", va="bottom")

            self._plot_feature_pair(axes[1], "Gesamtgewicht", "Beschleunigung", weight, accel, power,
                                    cmap, (875, 6.5), "weight (kg)", "acceleration (sec.)")
            axes[1].set_title("(b)")

            self._plot_feature_pair(axes[2], "LeistungPS", "Beschleunigung", power, accel, weight,
                                    cmap, (136, 6.5), "power (hp)", "acceleration (sec.)")
            axes[2].set_title("(c)")
        finally:
            # Keep predictors/response consistent with self.models.
            self.set_predictor_and_response(*saved)

        # One shared legend. Handles use the same style table as the plotted
        # lines, so their count and look always match the models.
        custom_lines = []
        for i in range(len(self.variable_models)):
            linestyle, colour = self._style(i)
            custom_lines.append(Line2D([0], [0], linestyle=linestyle, color=colour, lw=4))
        fig.legend(custom_lines, [model.toString().replace("Regressor", "") for model in self.variable_models])

        plt.show()

    def predict_single_point(self, datapoint):
        """Predict one data point with every model in ``self.models``.

        Args:
            datapoint: iterable of ``(column_name, value)`` pairs, e.g.
                ``(["LeistungPS", 136], ["Beschleunigung", 6.5])``.

        Returns:
            dict mapping each model's ``toString()`` to the first element of
            its ``response_predict`` result. Each result is also printed.
        """
        # Build the one-row input once. Materialising first also keeps a
        # one-shot iterator usable for every model.
        pairs = list(datapoint)
        keys = [pair[0] for pair in pairs]
        values = [pair[1] for pair in pairs]
        linedata = pd.DataFrame(data=[values], columns=keys)

        res = {}
        for regressor in self.models:
            y_pred = regressor.response_predict(linedata)
            res[regressor.toString()] = y_pred[0]
            print(regressor.toString(), y_pred)
        return res
        
                

        

        






#rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["LeistungPS"], ["Gesamtgewicht"])
#rt = Regressiontest([NN1Regressor], "data.csv", ["LeistungPS"], ["Hoechstgeschwindigkeit"])
#rt.create_graph()
#rt.plot_triple_3_features()

# print("predicting Gesamtgewicht")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["LeistungPS", "Beschleunigung"], ["Gesamtgewicht"])
# rt.predict_single_point((["LeistungPS", 136], ["Beschleunigung",6.5]))
# print("####")
# print("predicting Beschleunigung")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["LeistungPS", "Gesamtgewicht"], ["Beschleunigung"])
# rt.predict_single_point((["LeistungPS",136], ["Gesamtgewicht",2500]))
# print("####")
# print("predicting LeistungPS")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["Gesamtgewicht", "Beschleunigung"], ["LeistungPS"])
# rt.predict_single_point((["Gesamtgewicht",2500], ["Beschleunigung",6.5]))

# print("predicting Gesamtgewicht by LeistungPS")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["LeistungPS"], ["Gesamtgewicht"])
# rt.predict_single_point((["LeistungPS", 136],))
# print("####")
# print("predicting Gesamtgewicht by acceleration")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["Beschleunigung"], ["Gesamtgewicht"])
# rt.predict_single_point((["Beschleunigung", 6.5],))
# print("####")

# print("predicting Beschleunigung by LeistungPS")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["LeistungPS"], ["Beschleunigung"])
# rt.predict_single_point((["LeistungPS",136],))
# print("####")
# print("predicting Beschleunigung by Gesamtgewicht")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["Gesamtgewicht"], ["Beschleunigung"])
# rt.predict_single_point((["Gesamtgewicht",875],))

# print("predicting LeistungPS by Beschleunigung")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["Beschleunigung"], ["LeistungPS"])
# rt.predict_single_point((["Beschleunigung",6.5],))
# print("####")
# print("predicting LeistungPS by Gesamtgewicht")
# rt = Regressiontest([DNNRegressor, PolynomialRegressor, LinearRegressor], "data.csv", ["Gesamtgewicht"], ["LeistungPS"])
# rt.predict_single_point((["Gesamtgewicht",875],))
