import numpy as np
from sklearn import linear_model
import os
from os.path import join
import json

# Directory that contains this script; data folders and the output file are
# resolved relative to it.
datasets_dir_path = os.path.dirname(__file__)
# File that receives the averaged regression weights as JSON.
linear_weight_save_path = os.path.join(datasets_dir_path, 'linear_weight.json')

def stdError_func(y_test, y):
    """Return the root-mean-square difference between ``y_test`` and ``y``.

    Both arguments must support element-wise subtraction (NumPy arrays or
    scalars); plain Python lists are not converted here.
    """
    residuals = y_test - y
    mean_squared = np.mean(residuals ** 2)
    return np.sqrt(mean_squared)

def R2_1_func(y_test, y):
    """Return 1 - SS_res / SS_tot for predictions ``y_test`` against ``y``.

    ``SS_res`` is the sum of squared differences between ``y_test`` and ``y``;
    ``SS_tot`` is the sum of squared deviations of ``y`` from its own mean.
    ``y`` must provide ``.mean()`` (e.g. a NumPy array). When every element
    of ``y`` is equal, ``SS_tot`` is zero and the division is by zero.
    """
    residual_ss = ((y_test - y) ** 2).sum()
    total_ss = ((y.mean() - y) ** 2).sum()
    ratio = residual_ss / total_ss
    return 1 - ratio

def R2_2_func(y_test, y):
    """Return 1 - RMSE(y_test, y) / RMSE(mean(y), y).

    The denominator is the error of a baseline that always predicts the mean
    of ``y``, so the result is 1 for a perfect fit and 0 for a fit that is no
    better than that baseline.

    ``y`` is converted to a float array before the baseline is built.
    Without this, an integer-typed ``y`` would make the baseline array
    integer-typed as well and the mean would be truncated when stored in it.
    When every element of ``y`` is equal the baseline error is zero and the
    division is by zero.
    """
    y = np.asarray(y, dtype=float)
    # Constant prediction with the same shape as y, always stored as float.
    baseline = np.full_like(y, y.mean())
    return 1 - stdError_func(y_test, y) / stdError_func(baseline, y)

def load_model_score(json_path):
    """Load a JSON list of score records and regroup it by key.

    The file at ``json_path`` must contain a non-empty list of dicts. The
    keys of the first record define the output columns; each column is the
    list of that key's values across all records, in file order. Keys that
    appear only in later records are ignored, and a later record missing one
    of the first record's keys raises ``KeyError``.
    """
    with open(json_path, 'r') as handle:
        records = json.load(handle)

    # Column names come from the first record only.
    usl_scores = {name: [] for name in records[0].keys()}
    for record in records:
        for name, column in usl_scores.items():
            column.append(record[name])
    return usl_scores

def load_human_scores(json_path):
    """Load human annotations and average them per sample and per key.

    The file at ``json_path`` must contain a non-empty list of dicts whose
    values are non-empty lists of numbers. The keys of the first sample
    define the output; for each key the result holds one arithmetic mean per
    sample, in file order.
    """
    with open(json_path, 'r') as handle:
        all_samples = json.load(handle)

    # Output keys are taken from the first sample only.
    scores_dict = {name: [] for name in all_samples[0].keys()}
    for sample in all_samples:
        for name, averages in scores_dict.items():
            ratings = sample[name]
            averages.append(sum(ratings) / len(ratings))
    return scores_dict

def load_json(load_path):
    """Return the parsed content of the JSON file at ``load_path``."""
    with open(load_path, 'r') as source:
        return json.load(source)

def linear_regr(model_scores_dict, human_score, quality):
    """Fit a linear regression from metric scores to human scores.

    Args:
        model_scores_dict: Mapping from metric name to that metric's scores.
            Each value becomes one feature column, in the dict's key order
            (assumes all values are equal-length numeric sequences -- TODO
            confirm against the score files).
        human_score: Target values, one per sample.
        quality: Key under which the fitted coefficients are recorded.

    Returns:
        The fitted coefficients as a list, ordered like the keys of
        ``model_scores_dict``.

    Side effects:
        Appends a separate copy of the coefficients to the module-level dict
        ``quality_record`` under ``quality``. That dict must exist at module
        level before this function is called.
    """
    # One column per metric, one row per sample.
    x_scores = np.array(list(model_scores_dict.values())).T
    y_score = np.array(human_score)

    cft = linear_model.LinearRegression()
    cft.fit(x_scores, y_score)

    # tolist() is called twice on purpose: the recorded list and the returned
    # list stay independent objects.
    quality_record.setdefault(quality, []).append(cft.coef_.tolist())
    return cft.coef_.tolist()

def get_single_quality_weight(record, metric_names):
    """Collapse the recorded coefficient lists into one named weight set per key.

    For every key of ``record`` the stored coefficient lists are averaged
    position by position, rounded to two decimals, and negative results are
    replaced by ``0``. The i-th averaged value is stored under
    ``metric_names[i]``; positions beyond ``len(metric_names)`` are dropped.

    ``record`` is modified in place and also returned.
    """
    for quality in record.keys():
        averaged = np.mean(np.array(record[quality]), axis=0).tolist()
        rounded = [round(value, 2) for value in averaged]
        # Negative weights are clamped to zero.
        record[quality] = {
            metric: max(rounded[position], 0)
            for position, metric in enumerate(metric_names)
        }
    return record

def save_json(lst, save_path):
    """Serialize ``lst`` as JSON into ``save_path``, replacing any existing file."""
    with open(save_path, 'w') as out_file:
        json.dump(lst, out_file)


if __name__ == '__main__':
    # Workflow notes:
    #
    #   1. Train the FI/NUF/CR/IES metrics on the NUF/CR/IES datasets to get
    #      the inner linear weights:
    #          FI  = W1 * D-PPL + W2 * LT R + W3 * LR
    #          NUF = W4 * LSC + W5 * V UP + W6 * 5-NUF
    #          CR  = W7 * GRADE + W8 * AB-AC + W9 * AB-BA
    #          IES = W10 * Dist-n + W11 * D-MLM + W12 * 5-IES
    #
    #      python regression.py --model_score_path   --human_anno_path
    #
    #   2. Train the IM2 framework on the Overall datasets to get the linear
    #      weights:
    #          IM2 = W13 * FI + W14 * NUF + W15 * CR + W16 * IES
    #
    #      python regression.py --model_score_path   --human_anno_path
    #
    # NOTE(review): no command-line parsing is visible in this file; the
    # paths used below are hard-coded.

    # Sub-metric score names (the original note says there are 12 "*_s"
    # names in total). These three are used to obtain the CR weights:
    #     w1 * grade_s + w2 * abac_s + w3 * abba_s = CR annotation
    metric_name = ['grade_s', 'abac_s', 'abba_s']

    model_dir_path = join(datasets_dir_path, 'metric-score')
    human_anno_path = join(datasets_dir_path, 'dataset/CR/label/txt')

    # Module-level accumulator that linear_regr() appends to:
    # quality name -> list of coefficient lists.
    quality_record = {}

    for name in metric_name:
        model_score_path = join(model_dir_path, "%s_score.json" % name)
        if not os.path.exists(model_score_path):
            continue

        model_scores = load_model_score(model_score_path)
        human_scores = load_human_scores(human_anno_path)

        for k, v in human_scores.items():
            weights = linear_regr(model_scores, v, k)
            print(k, weights)

        # linear_regr() builds its feature columns in the key order of
        # model_scores, so the coefficients are labelled with those keys.
        saved_metric = list(model_scores.keys())

        # get_single_quality_weight() rewrites the dict it receives, so it is
        # given a shallow copy; the accumulator must keep its lists for the
        # next score file.
        # NOTE(review): records accumulate across score files, so the saved
        # weights are a running average over the files processed so far --
        # confirm this is intended.
        avg_quality_weight = get_single_quality_weight(dict(quality_record), saved_metric)
        save_json(avg_quality_weight, linear_weight_save_path)
