import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tabularDataImporter import TabularDataImporter
from typing import List, Dict, Tuple
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score




class LinearRegressor:
    """Ordinary-least-squares regressor built on scikit-learn's ``LinearRegression``.

    On construction the CSV file is loaded through ``TabularDataImporter``,
    preprocessed, categorically encoded and split into train/test sets.
    Call :meth:`train` before :meth:`response_predict`, :meth:`analyse_loss`
    or :meth:`evaluate`, since those use the fitted model.
    """

    def __init__(self, csvfile: str, predictors: List[str], response: List[str]) -> None:
        """Load ``csvfile`` and prepare the train/test split.

        Args:
            csvfile: Path of the CSV file handed to ``TabularDataImporter``.
            predictors: Column names used as model inputs.
            response: Column name(s) used as the regression target.
        """
        self.data: TabularDataImporter = TabularDataImporter(csvfile)
        self.data.preprocess_all()
        self.data.encodeCategorical()

        self.X_train, self.X_test, self.y_train, self.y_test = self.data.split_train_test_set(
            predictors, response
        )
        self.model = linear_model.LinearRegression()

    def toString(self):
        """Return the human-readable name of this regressor."""
        return "LinearRegressor"

    def __str__(self):
        # Lets str()/print() show the same name as toString().
        return self.toString()

    def train(self, epochs=None) -> None:
        """Fit the linear model on the training split.

        Args:
            epochs: Ignored. OLS has a closed-form solution; the parameter is
                kept, presumably so this class matches the signature of other
                (iterative) regressors — TODO confirm against callers.
        """
        self.model.fit(self.X_train, self.y_train)

    def response_predict(self, test_data):
        """Return the fitted model's predictions for ``test_data``."""
        return self.model.predict(test_data)

    def analyse_loss(self):
        """Plot predicted vs. observed responses on the test split and show it.

        A dashed y = x line marks perfect prediction; points far from it are
        poorly predicted samples.
        """
        y_pred = np.asarray(self.model.predict(self.X_test)).ravel()
        y_obs = np.asarray(self.y_test).ravel()

        # Labels are required; without them plt.legend() has nothing to show
        # and matplotlib emits a "No artists with labels" warning.
        plt.scatter(y_pred, y_obs, label="test samples")
        lo = min(y_pred.min(), y_obs.min())
        hi = max(y_pred.max(), y_obs.max())
        plt.plot([lo, hi], [lo, hi], "r--", label="pred = observed")

        plt.xlabel('pred')
        plt.ylabel('observed')
        plt.legend()
        plt.grid(True)
        plt.show()

    def evaluate(self):
        """Print and return test-set metrics of the fitted model.

        Returns:
            Tuple ``(mse, r2, coefficients)``, where ``mse`` is the mean squared
            error and ``r2`` the coefficient of determination on the test split,
            and ``coefficients`` is the fitted ``model.coef_``.
        """
        y_pred = self.model.predict(self.X_test)
        # Compute each metric once and reuse it for printing and returning.
        mse = mean_squared_error(self.y_test, y_pred)
        r2 = r2_score(self.y_test, y_pred)

        print("Coefficients: \n", self.model.coef_)
        print("Mean squared error: %.2f" % mse)
        print("Coefficient of determination: %.2f" % r2)
        return mse, r2, self.model.coef_

# lre = LinearRegressor("data.csv", ["LeistungPS"], ["Beschleunigung"])
# lre.train()
# #lre.analyse_loss()
# #lre.evaluate()

# print(pd.DataFrame([1,2,3]).shape)
# print(lre.model.predict(pd.DataFrame([1,2,3])))