import numpy as np
import json
import csv
import pickle
import argparse
from nlgeval import NLGEval
from tqdm import tqdm
from bert_score import BERTScorer
import os.path
from os import path
import random

# Fix the RNG seed so any use of `random` in this script is repeatable.
random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--annotated_files",
    default="annotated/merged_annotated_dd.json",
    help="JSON file with the annotated rows to score.",
)
parser.add_argument(
    "--dump_name",
    default="reseval_result/dd/reference_based.json",
    help="Path the correlation results are written to; must not exist yet.",
)
args = parser.parse_args()

# Refuse to overwrite an earlier result file. This is an explicit raise rather
# than an `assert` because assertions are stripped when Python runs with -O.
if os.path.exists(args.dump_name):
    raise FileExistsError(
        "Output file already exists: {}".format(args.dump_name)
    )

# Read the JSON input as UTF-8 instead of relying on the platform default.
with open(args.annotated_files, "r", encoding="utf-8") as f:
    data_metrics = json.load(f)

# Scorers shared by compute_metrics() below.
nlgeval = NLGEval(metrics_to_omit=["CIDEr", "SkipThoughtCS"])
# NOTE(review): device is hard-coded to "cuda"; presumably this script is only
# run on GPU machines -- confirm before running elsewhere.
scorer = BERTScorer(lang="en", device="cuda", rescale_with_baseline=True)

# compute nlgeval metrics
def compute_metrics(data, include_bert_score=False):
    """Attach automatic metric scores to every row of ``data`` in place.

    For each row, ``row["response"]`` is scored against ``row["reference"]``
    with the module-level ``nlgeval`` object and the result is stored under
    ``row["metrics"]``. When the reference is blank (empty or whitespace
    only), the stripped response is used as its own reference.

    Args:
        data: iterable of mutable rows providing "reference" and "response".
        include_bert_score: when true, also store "bert_prec", "bert_rec"
            and "bert_f1" (from the module-level ``scorer``) in
            ``row["metrics"]``.

    Returns:
        None; rows are modified in place.
    """
    for row in tqdm(data):
        reference = row["reference"]
        hypothesis = row["response"]

        # Blank reference: fall back to scoring the response against itself.
        if reference.strip() == "":
            references = [hypothesis.strip()]
        else:
            references = [reference]

        row["metrics"] = nlgeval.compute_individual_metrics(
            ref=references, hyp=hypothesis
        )

        if not include_bert_score:
            continue

        # One candidate with its list of references; each returned value is
        # reduced to a Python scalar before being stored.
        precision, recall, f1 = (
            score.data.cpu().item()
            for score in scorer.score([hypothesis], [references])
        )
        row["metrics"]["bert_prec"] = precision
        row["metrics"]["bert_rec"] = recall
        row["metrics"]["bert_f1"] = f1


# Score every loaded row in place. BERTScore is enabled here, so each row's
# "metrics" entry also receives the bert_prec / bert_rec / bert_f1 keys.
compute_metrics(data_metrics, include_bert_score=True)


# compute correlation metrics
import scipy


def compute_correlations(
    data, metric="Bleu_2", method_cnt_check=5, method=None
):
    """Correlate one automatic metric with the human scores.

    Args:
        data: iterable of rows. For every row that is used,
            ``row["metrics"][metric]`` and ``row["score"]`` are read and
            converted with ``float``.
        metric: key of the automatic metric inside ``row["metrics"]``.
        method_cnt_check: currently unused (see the todo below); kept so
            existing callers that pass it keep working.
        method: when truthy, only rows whose ``row["model"]`` equals this
            value are used; otherwise every row is used.

    Returns:
        dict with the keys "spearmanr", "pearsonr" and "kendall_tau", each
        holding the (correlation, pvalue) result returned by the matching
        ``scipy.stats`` function.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.stats`` is available as an attribute on every
    # SciPy release.
    from scipy import stats

    # todo: method_cnt_check
    vals_aut = []
    vals_hum = []
    for row in data:
        if method and row["model"] != method:
            continue
        vals_aut.append(float(row["metrics"][metric]))
        vals_hum.append(float(row["score"]))

    # Build each array once and reuse it for all three statistics.
    automatic = np.array(vals_aut)
    human = np.array(vals_hum)

    pearsonr = stats.pearsonr(automatic, human)
    spearmanr = stats.spearmanr(automatic, human)
    kendall_tau = stats.kendalltau(automatic, human)
    # correlation, pvalue
    return {
        "spearmanr": spearmanr,
        "pearsonr": pearsonr,
        "kendall_tau": kendall_tau,
    }


# Keys looked up in each row's "metrics" dict when computing correlations.
# NOTE(review): "EmbeddingAverageCosineSimilairty" is presumably spelled this
# way to match the key produced by nlgeval -- verify before "correcting" it.
# NOTE(review): "bert_f1" is stored by compute_metrics but is not listed here;
# confirm whether that omission is intended.
metrics_of_interest = [
    "Bleu_1",
    "Bleu_2",
    "Bleu_3",
    "Bleu_4",
    "ROUGE_L",
    "METEOR",
    "EmbeddingAverageCosineSimilairty",
    "bert_prec",
    "bert_rec",
]

# Per-metric correlation coefficients and p-values against the human scores.
ret = {}
for k in metrics_of_interest:
    info = compute_correlations(data_metrics, metric=k, method_cnt_check=5)
    # Each SciPy result unpacks as (correlation, pvalue); unpacking avoids
    # depending on result attribute names.
    pearson_r, pearson_p = info["pearsonr"]
    spearman_r, spearman_p = info["spearmanr"]
    kendall_t, kendall_p = info["kendall_tau"]
    ret[k] = {
        "pearsonr": float(pearson_r),
        "pearsonr_pvalue": float(pearson_p),
        "spearmanr": float(spearman_r),
        "spearmanr_pvalue": float(spearman_p),
        "kendall_tau": float(kendall_t),
        "kendall_tau_pvalue": float(kendall_p),
    }

if args.dump_name is not None:
    # Create the output directory first so the results of the scoring run
    # above are not lost to a missing-directory error on open().
    out_dir = os.path.dirname(args.dump_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.dump_name, "w") as f:
        json.dump(ret, f)


print(json.dumps(ret, indent=2))
