import json
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import re
from sklearn.metrics import cohen_kappa_score
import krippendorff


# Relation types that are kept when collecting triples; a line whose relation
# is not listed here is skipped by collect_train_dev_triples().
ConceptNetRelations = [
    'AtLocation',
    'CapableOf',
    'Causes',
    'CausesDesire',
    'Desires',
    'HasA',
    'HasPrerequisite',
    'HasProperty',
    'HasSubevent',
    'IsA',
    'MadeOf',
    'MotivatedByGoal',
    'NotDesires',
    'PartOf',
    'ReceivesAction',
    'UsedFor',
]

def transform_samples():
    """Convert the sampled triples JSON into a one-line-per-sample text file.

    Reads ``./sampled_close_triples.json``, which must hold a JSON array whose
    items each supply (at least) five positional values, used in this order:
    head, relation, obj, masked sentence, predictions. One formatted line per
    item is written to ``samples.txt`` in the current working directory.

    Raises:
        OSError: if the input cannot be read or the output cannot be written.
        json.JSONDecodeError: if the input is not valid JSON.
        IndexError: if an item has fewer than five values.
    """
    path = "./sampled_close_triples.json"
    print("start")
    # Load the input before touching the output so that a missing or
    # malformed JSON file does not truncate an existing samples.txt.
    with open(path, 'r') as f:
        samples = json.load(f)
    template = "head: {}, relation: {}, obj: {}, masked sentence: {}, predictions: {}\n"
    # The context manager closes the output file even if formatting fails.
    with open("samples.txt", 'w') as ofile:
        ofile.writelines(template.format(*sample) for sample in samples)
    print("done")

def F_beta_score(beta, precision, novelty):
    """Return the F-beta style weighted harmonic mean of precision and novelty.

    Computes ``(1 + beta^2) * precision * novelty / (beta^2 * precision + novelty)``.

    Args:
        beta: weighting factor applied (squared) to ``precision`` in the
            denominator.
        precision: precision value.
        novelty: novelty value, on the same scale as ``precision``.

    Returns:
        The weighted harmonic mean, or ``0.0`` when the denominator is zero
        (for example when both inputs are 0), where the score would otherwise
        be undefined.
    """
    beta_sq = beta * beta
    denominator = beta_sq * precision + novelty
    if denominator == 0:
        # Avoid ZeroDivisionError for the degenerate all-zero case.
        return 0.0
    return (1 + beta_sq) * precision * novelty / denominator

def _add_single_word_pairs(path, rel2pair, allowed_relations, positive_only):
    """Add the lower-cased (subj, obj) pairs found in one tab-separated file to rel2pair."""
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no triple.
                continue
            rel, subj, obj, flag = line.split('\t')
            if positive_only and flag != '1':
                continue
            # Only triples whose subject and object are single tokens are kept.
            if len(subj.split()) > 1 or len(obj.split()) > 1:
                continue
            if rel not in allowed_relations:
                continue
            rel2pair[rel].add((subj.lower(), obj.lower()))


def collect_train_dev_triples():
    """Collect the known (subject, object) pairs per relation from train and dev data.

    Reads ``./data/CKBC/train100k.txt`` and ``./data/CKBC/dev_total.txt``.
    Each non-blank line must contain four tab-separated fields: relation,
    subject, object and a fourth column. For the dev file only lines whose
    fourth column is exactly ``'1'`` are used; the fourth column of the train
    file is ignored. A triple is kept only when its relation is listed in
    ``ConceptNetRelations`` and both subject and object are single
    whitespace-delimited tokens.

    Returns:
        A ``defaultdict`` mapping relation name to a set of lower-cased
        ``(subject, object)`` tuples.

    Raises:
        OSError: if either file cannot be opened.
        ValueError: if a non-blank line does not have exactly four fields.
    """
    train_pth = "./data/CKBC/train100k.txt"
    dev_pth = "./data/CKBC/dev_total.txt"
    rel2pair = defaultdict(set)
    allowed_relations = frozenset(ConceptNetRelations)
    _add_single_word_pairs(train_pth, rel2pair, allowed_relations, positive_only=False)
    _add_single_word_pairs(dev_pth, rel2pair, allowed_relations, positive_only=True)
    print("Collecting relation to pair finished")
    return rel2pair

def _percentage(numerator, denominator):
    """Return numerator / denominator as a percentage, or 0.0 when the denominator is 0."""
    if not denominator:
        return 0.0
    return numerator / denominator * 100


def _parse_annotated_line(line):
    """Split one annotated sample line into (subj, relation, predictions, correctness, novelty)."""
    subj = re.search(r"head: (.*?),", line).groups()[0].lower()
    relation = re.search(r"relation: (.*?),", line).groups()[0]
    raw_predictions = re.search(r"predictions: \[(.*?)\]", line).groups()[0]
    predictions = [item.strip().strip("'").lower() for item in raw_predictions.split(',')]
    assert len(predictions) == 5
    # The annotator scores precede the first "head"; partition() cuts at that
    # first occurrence only, so a later "head" (e.g. the word "forehead")
    # cannot break the parsing.
    score_text = line.partition("head")[0]
    scores = [int(token) for token in score_text.split()]
    if len(scores) != 10:
        raise ValueError(
            "expected 10 annotation scores before 'head', got {}".format(len(scores)))
    # Scores alternate: correctness, novelty for rank 1, then rank 2, ...
    return subj, relation, predictions, scores[0::2], scores[1::2]


def compute_accuracy_novelty(path: str = "./samples.txt", plot=False):
    """Compute Precision@k and Novelty@k (k = 1..5) from one annotated samples file.

    Every non-blank line of ``path`` must start with ten whitespace-separated
    integers (correctness and novelty judgment for each of the five ranked
    predictions, interleaved) followed by the
    ``head: ..., relation: ..., ... predictions: [...]`` text.

    A prediction at rank r counts towards "correct@k" for every k >= r when
    its correctness score is non-zero. It additionally counts towards
    "novel@k" when its novelty score is non-zero and the pair
    ``(head, prediction)`` is not in ``rel2pair[relation]``.

    NOTE: ``rel2pair`` is read as a module-level global; in this file it is
    only bound inside the ``__main__`` block.

    Args:
        path: annotated samples file to read.
        plot: when true, also save a precision/novelty curve to
            ``./figures/precision_novelty.png``.

    Returns:
        ``(noveltys, precisions, all_correctness, all_novelty)``: the two
        metric lists hold percentages rounded to one decimal for k = 1..5;
        the last two are the flat lists of raw judgments in file order.
        A metric whose denominator is zero (empty file, or no correct
        prediction up to that k) is reported as 0.0.
    """
    top_k = 5
    correct_at_k = [0] * top_k
    novel_at_k = [0] * top_k
    total = 0
    all_correctness = []
    all_novelty = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing newline.
                continue
            subj, relation, predictions, correctness, novelty = _parse_annotated_line(line)
            all_correctness.extend(correctness)
            all_novelty.extend(novelty)
            total += 1
            for rank, (is_correct, is_novel) in enumerate(zip(correctness, novelty)):
                if not is_correct:
                    continue
                # Novel only if the annotator said so AND the pair is unseen.
                unseen = bool(is_novel) and (subj, predictions[rank]) not in rel2pair[relation]
                # A hit at this rank counts for every cut-off k at or beyond it.
                for k in range(rank, top_k):
                    correct_at_k[k] += 1
                    if unseen:
                        novel_at_k[k] += 1

    raw_precisions = [_percentage(correct_at_k[k], (k + 1) * total) for k in range(top_k)]
    raw_noveltys = [_percentage(novel_at_k[k], correct_at_k[k]) for k in range(top_k)]
    precisions = [round(value, 1) for value in raw_precisions]
    noveltys = [round(value, 1) for value in raw_noveltys]

    # With beta == 1 the F-score denominator is precision + novelty; report
    # 0.0 instead of dividing by zero when both are 0.
    harmonic_means = [
        F_beta_score(1, p, n) if (p + n) else 0.0
        for p, n in zip(precisions, noveltys)
    ]

    for k in range(top_k):
        print("Precision@{}: {}, Novelty@{}: {}".format(
            k + 1, raw_precisions[k], k + 1, raw_noveltys[k]))
    for score in harmonic_means:
        print(score)

    if plot:
        paper_rc = {'lines.linewidth': 2, 'lines.markersize': 13}
        sns.set_context("paper", rc=paper_rc)
        sns.set_style('darkgrid')
        plt.figure(figsize=(8, 5))
        plt.plot(noveltys, precisions, label='Top@K', ms=10, marker='o')
        plt.xlabel("Novelty(%)", fontsize=17)
        plt.ylabel("Precision(%)", fontsize=17)
        plt.ylim(56, 81)
        # Label every point with its cut-off.
        for k, (x1, y1) in enumerate(zip(noveltys, precisions), start=1):
            plt.text(x1 + 0.3, y1 + 0.8, 'Top@{}'.format(k),
                     ha='right', va='center', fontsize=8, rotation=0)
        plt.legend()
        plt.savefig("./figures/precision_novelty.png", dpi=360)
    return noveltys, precisions, all_correctness, all_novelty

def merge_all_human_eval(paths: list):
    """Average the per-annotator metrics, plot them and print agreement scores.

    Each entry of ``paths`` is passed to ``compute_accuracy_novelty`` (one
    file per annotator). The Precision@k / Novelty@k values (k = 1..5) are
    averaged over annotators, printed, and plotted to
    ``./figures/precision_novelty_<num_annotators>.png``. Finally
    ``krippendorff.alpha`` is printed twice: first for the correctness
    judgments, then for the novelty judgments, each given as one row of
    judgments per annotator.

    Args:
        paths: annotated sample files, one per annotator.

    Raises:
        ValueError: if ``paths`` is empty (the averages would be undefined).
    """
    if not paths:
        raise ValueError("merge_all_human_eval requires at least one annotation file path")
    num_annotators = len(paths)
    top_k = 5
    total_noveltys = [0.0] * top_k
    total_precisions = [0.0] * top_k
    all_correctness_annotators = []
    all_novelty_annotators = []
    for path in paths:
        noveltys, precisions, all_correctness, all_novelty = compute_accuracy_novelty(path, plot=False)
        all_correctness_annotators.append(all_correctness)
        all_novelty_annotators.append(all_novelty)
        for k in range(top_k):
            total_noveltys[k] += noveltys[k]
            total_precisions[k] += precisions[k]
    total_noveltys = [score / num_annotators for score in total_noveltys]
    total_precisions = [score / num_annotators for score in total_precisions]
    for k, (precision, novelty) in enumerate(zip(total_precisions, total_noveltys), start=1):
        print("Averaged Precision@{}: {}, Novelty@{}: {}".format(k, precision, k, novelty))

    paper_rc = {'lines.linewidth': 2, 'lines.markersize': 13}
    sns.set_context("paper", rc=paper_rc)
    sns.set_style('darkgrid')
    params = {'legend.fontsize': 16,
              'legend.handlelength': 2}
    plt.rcParams.update(params)
    plt.figure(figsize=(7, 5.5))
    plt.plot(total_noveltys, total_precisions, label='Top@K', ms=10, marker='o')
    plt.tick_params(labelsize=15)
    plt.xlabel("Novelty(%)", fontsize=15)
    plt.ylabel("Precision(%)", fontsize=15)
    plt.ylim(52, 81)
    # Annotate every point with its cut-off k.
    for k, (x1, y1) in enumerate(zip(total_noveltys, total_precisions), start=1):
        plt.text(x1 + 0.35, y1 + 1.2, str(k), ha='right', va='center', fontsize=15, rotation=0)
    plt.legend()
    plt.savefig(f"./figures/precision_novelty_{num_annotators}.png", dpi=360)
    print("precision-novelty plot saved!")

    # Inter-annotator agreement: correctness first, then novelty.
    # NOTE(review): presumably every annotator file lists the same samples in
    # the same order so the rows line up; confirm against the data.
    tmp_score = krippendorff.alpha(all_correctness_annotators)
    print(tmp_score)
    tmp_score = krippendorff.alpha(all_novelty_annotators)
    print(tmp_score)


if __name__ == "__main__":
    # get the true triples from training/dev set
    # NOTE: compute_accuracy_novelty() looks up `rel2pair` as a module-level
    # global, so it has to be bound under exactly this name before any
    # evaluation function is called.
    rel2pair = collect_train_dev_triples()

    # Single-annotator variant (also saves the plot):
    # compute_accuracy_novelty("./samples.txt", plot=True)
    # Merge the judgments of two annotators, one file each.
    merge_all_human_eval(["./samples.txt", "./samples2.txt"])