from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons.metrics import RSquare

from tabularDataImporter import TabularDataImporter
            

class NN1Regressor:
    """Linear Keras regressor (Normalization + one Dense layer) on CSV data.

    The CSV is loaded, preprocessed and split by ``TabularDataImporter``.
    The model is compiled with Adam and the mean-absolute-error loss.
    """

    def __init__(self, csvfile: str, predictors: List[str], response: List[str],
                 learning_rate: float = 0.1) -> None:
        """Load *csvfile*, split it and build the model.

        Args:
            csvfile: Path of the CSV file handed to ``TabularDataImporter``.
            predictors: Column names used as inputs. An empty list means
                "all columns"; the input width is then taken from
                ``X_train.columns``.
            response: Column name(s) to predict.
            learning_rate: Adam learning rate (defaults to the previously
                hard-coded 0.1).
        """
        # The importer object (not a DataFrame) -- its frame is `.dataFrame`.
        self.data = TabularDataImporter(csvfile)
        self.data.preprocess_all()
        self.data.encodeCategorical()
        # Dataset-specific: "Gesamtgewicht" is divided by 100 before splitting
        # (presumably a scale/unit adjustment -- confirm against the data).
        self.data.dataFrame["Gesamtgewicht"] = self.data.dataFrame["Gesamtgewicht"] / 100

        self.X_train, self.X_test, self.y_train, self.y_test = (
            self.data.split_train_test_set(predictors, response))

        npdata = np.array(self.X_train)
        # Explicit predictors fix the input width; otherwise every column of
        # the training features is used.
        dim0 = len(predictors) if predictors else len(self.X_train.columns)

        data_normalizer = layers.Normalization(
            input_shape=[dim0, ], axis=self._normalization_axis(dim0))
        data_normalizer.adapt(npdata)
        self.model = tf.keras.Sequential([
            data_normalizer,
            # One output unit per response column (1 for a single response).
            layers.Dense(units=self._output_units(self.y_train)),
        ])
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mean_absolute_error')

    @staticmethod
    def _normalization_axis(n_features: int):
        """Return the ``axis`` argument for ``layers.Normalization``.

        One feature keeps ``axis=None`` (a single scalar mean/variance, which
        also accepts 1-D input). Several features need per-column statistics
        (``axis=-1``); otherwise differently scaled columns share one
        mean/variance.
        """
        return None if n_features == 1 else -1

    @staticmethod
    def _output_units(targets) -> int:
        """Return the number of target columns in *targets* (1 for 1-D data)."""
        arr = np.asarray(targets)
        if arr.ndim == 2 and arr.shape[1] > 0:
            return int(arr.shape[1])
        return 1

    @staticmethod
    def _align_targets(targets, predictions) -> np.ndarray:
        """Return *targets* as a float32 array reshaped to ``predictions``' shape.

        A 1-D ``(n,)`` target combined element-wise with ``(n, 1)``
        predictions would broadcast to ``(n, n)``; reshaping avoids that.
        """
        return np.asarray(targets, dtype=np.float32).reshape(np.shape(predictions))

    def toString(self):
        """Return the display name of this regressor."""
        return "NN1Regressor"

    def train(self, epochs: int = 1, validation_split: float = 0.2) -> None:
        """Fit the model on the training split and store ``self.history``.

        Args:
            epochs: Number of training epochs.
            validation_split: Fraction of the training data that Keras holds
                out for validation (defaults to the previous fixed 0.2).
        """
        self.history = self.model.fit(
            self.X_train,
            self.y_train,
            epochs=epochs,
            verbose=0,
            validation_split=validation_split)

    def response_predict(self, test_data):
        """Return the model's predictions for *test_data*."""
        return self.model.predict(test_data)

    def analyse_loss(self, ylim: Optional[Tuple[float, float]] = (0, 10)) -> None:
        """Plot the training loss and, if recorded, the validation loss per epoch.

        Requires :meth:`train` to have been called (it reads ``self.history``).

        Args:
            ylim: y-axis limits (defaults to the previous fixed ``(0, 10)``).
                Pass ``None`` to let matplotlib autoscale, e.g. when the loss
                is well above 10.
        """
        losses = self.history.history
        plt.plot(losses['loss'], label='loss')
        # 'val_loss' only exists when training used a validation split.
        if 'val_loss' in losses:
            plt.plot(losses['val_loss'], label='val_loss')
        if ylim is not None:
            plt.ylim(list(ylim))
        plt.xlabel('Epoch')
        plt.ylabel('Error')
        plt.legend()
        plt.grid(True)
        plt.show()

    def evaluate(self):
        """Score the model on the held-out test split.

        Returns:
            ``(MSE, R², test_loss)``. The first two come from the test-set
            predictions. ``test_loss`` is the compiled loss (mean absolute
            error) reported by ``model.evaluate``.
        """
        y_pred = self.model.predict(self.X_test)
        y_true = self._align_targets(self.y_test, y_pred)

        # Keras losses take (y_true, y_pred).
        MSE_val = tf.keras.losses.MeanSquaredError()(y_true, y_pred).numpy()

        r2 = RSquare()
        r2.update_state(y_true, y_pred)
        r2_val = r2.result().numpy()

        scalar_test_loss = self.model.evaluate(
            self.X_test,
            self.y_test, verbose=0)
        return MSE_val, r2_val, scalar_test_loss

    def predict_testdata(self):
        """Plot predictions over 100 evenly spaced inputs spanning ``X_test``.

        The inputs run from ``X_test.min()`` to ``X_test.max()``. This is
        intended for a single predictor. With several columns, each column is
        swept between its own min and max at the same time.
        """
        linedata = np.linspace(self.X_test.min(), self.X_test.max(), 100)
        y_pred = self.model.predict(linedata)

        plt.plot(linedata, y_pred)
        plt.show()


# Example usage (pass an empty predictors list to use all columns as predictors):
# nn1 = NN1Regressor("data.csv", ["Gesamtgewicht"], ["LeistungPS"])
# nn1.train(15)
# nn1.analyse_loss()
# print(nn1.evaluate())
# nn1.predict_testdata()

