import json
import sys
import os

# Add the parent of this file's directory to the module search path, right
# after the first entry. Presumably this is what lets the `evalpackage`
# imports below resolve when the script is run directly -- confirm against
# the repository layout.
sys.path.insert(1, os.path.dirname(os.path.dirname(__file__)))

from bert_score import score
from tqdm import tqdm

from evalpackage import Evaluator
from evalpackage.mlmscore import MLMScore
from evalpackage.clmscore import CLMScore
from evalpackage.qrelscore import QRelScore

def bert_score_sentence_level(hyp, ref):
    """Return the rescaled English BERTScore F1 for one hypothesis/reference pair.

    Args:
        hyp: Candidate sentence.
        ref: Reference sentence.

    Returns:
        The F1 output of ``bert_score.score`` converted to a plain Python
        list (expected to hold one value, since a single pair is scored).
    """
    # Wrap the single pair in lists because `score` is called with batches;
    # precision and recall are computed by `score` but not used here.
    _precision, _recall, f1 = score(
        [hyp],
        [ref],
        lang='en',
        rescale_with_baseline=True,
    )

    f1_array = f1.detach().cpu().numpy()
    return f1_array.tolist()


def bert_score_corpus_level(hyps, refs):
    """Return the rescaled English BERTScore F1 values for a batch of pairs.

    Args:
        hyps: Candidate sentences, passed straight to ``bert_score.score``.
        refs: Reference sentences, passed straight to ``bert_score.score``.

    Returns:
        The F1 output of ``bert_score.score`` converted to a plain Python
        list.
    """
    # Only the F1 component is reported; precision and recall are discarded.
    _precision, _recall, f1 = score(
        hyps,
        refs,
        lang='en',
        rescale_with_baseline=True,
    )

    f1_array = f1.detach().cpu().numpy()
    return f1_array.tolist()


def parse_squad1_input(dataset_filename):
    """Load a SQuAD v1-style JSON file and index its questions by id.

    The file must contain a top-level ``data`` list of documents, each with
    a ``paragraphs`` list; each paragraph has a ``context`` string and a
    ``qas`` list whose entries carry ``id`` and ``question``. Any other
    fields (for example a document ``title``) are ignored and need not be
    present.

    Args:
        dataset_filename: Path to the JSON dataset file.

    Returns:
        A dict mapping each question id to
        ``{'context': <paragraph text>, 'gold_question': <question text>}``.
        If an id occurs more than once, the last occurrence wins.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If one of the required fields listed above is missing.
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # default, so non-ASCII passages load identically on every machine.
    with open(dataset_filename, encoding='utf-8') as f:
        dataset = json.load(f)

    data_dict = {}

    for doc in dataset['data']:
        for par in doc['paragraphs']:
            context = par['context']
            for qa in par['qas']:
                qid = qa['id']
                data_dict[qid] = {
                    'context': context,
                    'gold_question': qa['question'],
                }

    return data_dict


def main_qrellr_baseline():
    """Score every SQuAD dev gold question with MLMScore and save the results.

    Loads the dataset through ``parse_squad1_input``, stores the value of
    ``MLMScore.forward_pass_score(gold_question, context)`` under the
    ``'qrellr'`` key of each record, and writes the annotated records to a
    JSON file. Both paths are hard-coded below.
    """
    source_path = '/path/to/home/to/your/project/squad1/dev-v1.1.json'
    output_path = '/path/to/home/to/your/project/qgbase/summary/qrellr-baseline-analysis/squad1-dev-qrellr.json'

    records = parse_squad1_input(source_path)
    scorer = MLMScore()

    # Records are annotated in place. Only the per-question dicts are
    # mutated, so iterating the outer mapping at the same time is safe.
    for record in tqdm(records.values()):
        record['qrellr'] = scorer.forward_pass_score(
            record['gold_question'],
            record['context'],
        )

    with open(output_path, 'w') as out_file:
        json.dump(records, out_file, indent=4)

def main_qrelga_baseline():
    """Score every SQuAD dev gold question with CLMScore and save the results.

    Loads the dataset through ``parse_squad1_input``, stores the value of
    ``CLMScore.diff_score(context, gold_question)`` under the ``'qrelga'``
    key of each record, and writes the annotated records to a JSON file.
    Both paths are hard-coded below.
    """
    source_path = '/path/to/home/to/your/project/squad1/dev-v1.1.json'
    output_path = '/path/to/home/to/your/project/qgbase/summary/qrelga-baseline-analysis/squad1-dev-qrelga.json'

    records = parse_squad1_input(source_path)
    scorer = CLMScore()

    # Note the argument order: this scorer takes the context first, then
    # the question. Records are annotated in place.
    for record in tqdm(records.values()):
        record['qrelga'] = scorer.diff_score(
            record['context'],
            record['gold_question'],
        )

    with open(output_path, 'w') as out_file:
        json.dump(records, out_file, indent=4)

def main_qrelscore_baseline():
    """Score every SQuAD dev gold question with QRelScore and save the results.

    Loads the dataset through ``parse_squad1_input``, stores the value of
    ``QRelScore.compute_score_flatten([context], [gold_question])`` under
    the ``'qrelscore'`` key of each record, and writes the annotated records
    to a JSON file. Both paths are hard-coded below.
    """
    source_path = '/path/to/home/to/your/project/squad1/dev-v1.1.json'
    output_path = '/path/to/home/to/your/project/qgbase/summary/qrelscore-baseline-analysis/squad1-dev-qrelscore.json'

    records = parse_squad1_input(source_path)
    scorer = QRelScore()

    # The scorer is called once per question with single-element lists:
    # the context as the ground-truth side, the question as the result side.
    for record in tqdm(records.values()):
        record['qrelscore'] = scorer.compute_score_flatten(
            [record['context']],
            [record['gold_question']],
        )

    with open(output_path, 'w') as out_file:
        json.dump(records, out_file, indent=4)


if __name__ == '__main__':
    # Run the three baseline analyses one after another, in a fixed order:
    # MLM-based, then CLM-based, then the combined QRelScore.
    for run_baseline in (
        main_qrellr_baseline,
        main_qrelga_baseline,
        main_qrelscore_baseline,
    ):
        run_baseline()