from tqdm import tqdm
import numpy as np
import sys
sys.path.append("m2score")
from m2score.m2scorer import load_annotation
from m2score.util import smart_open
from m2score.levenshtein import batch_multi_pre_rec_f1, batch_multi_pre_rec_f1_sent
from errant_score import batch_multi_pre_rec_f1_errant, batch_multi_pre_rec_f1_sent_errant, errant_load_annotation
from bart_score import BARTScorer
from bert_score import BERTScorer


def compute_sentm2(m2_file, hyp_file, references, scorer, args):
    """Score every hypothesis sentence on its own and average the results.

    Args:
        m2_file: Path handed to ``load_annotation``, which yields the source
            sentences and their gold edits.
        hyp_file: Path to the system output; it is read one hypothesis per
            line, with surrounding whitespace stripped.
        references: Per-sentence references, iterated in lockstep with the
            hypotheses and forwarded to the scoring routine.
        scorer: Scorer object forwarded unchanged to
            ``batch_multi_pre_rec_f1_sent``.
        args: Options object. ``args.beta`` is read here, and the whole
            object is also forwarded to the scoring routine.

    Returns:
        A ``(mean_score, score_lst)`` tuple: the sum of the per-sentence
        scores divided by the number of hypothesis lines, and the list of
        per-sentence scores.

    Raises:
        ZeroDivisionError: If the hypothesis file contains no lines.
    """
    source_sentences, gold_edits = load_annotation(m2_file)
    # Context manager guarantees the handle is closed even if reading fails.
    with smart_open(hyp_file, 'r') as fin:
        system_sentences = [line.strip() for line in fin]

    # NOTE: zip() stops at the shortest input, so a length mismatch between
    # hypotheses, sources, references and gold edits silently drops the tail;
    # the mean below is still divided by the number of hypothesis lines.
    aligned = zip(system_sentences, source_sentences, references, gold_edits)
    score_lst = [
        # Only the last element of the scorer's result is kept as the
        # sentence score.
        batch_multi_pre_rec_f1_sent(
            [hyp], [src], [golds], [refs], scorer, args, beta=args.beta
        )[-1]
        for hyp, src, refs, golds in tqdm(aligned)
    ]

    return sum(np.array(score_lst)) / len(system_sentences), score_lst


def compute_m2(m2_file, hyp_file, references, scorer, args):
    """Score the whole hypothesis file in a single batch call.

    Args:
        m2_file: Path handed to ``load_annotation``, which yields the source
            sentences and their gold edits.
        hyp_file: Path to the system output; it is read one hypothesis per
            line, with surrounding whitespace stripped.
        references: References forwarded unchanged to
            ``batch_multi_pre_rec_f1``.
        scorer: Scorer object forwarded unchanged to
            ``batch_multi_pre_rec_f1``.
        args: Options object. ``args.beta`` is read here, and the whole
            object is also forwarded to the scoring routine.

    Returns:
        A ``(score, None)`` tuple. ``score`` is the last element of the
        result of ``batch_multi_pre_rec_f1``; the second slot is always
        ``None``, keeping the two-element return shape used by
        ``compute_sentm2``.
    """
    source_sentences, gold_edits = load_annotation(m2_file)
    # Context manager guarantees the handle is closed even if reading fails.
    with smart_open(hyp_file, 'r') as fin:
        system_sentences = [line.strip() for line in fin]

    results = batch_multi_pre_rec_f1(
        system_sentences, source_sentences, gold_edits, references,
        scorer, args, beta=args.beta,
    )
    return results[-1], None


def compute_senterrant(m2_file, hyp_file, hyp_n, references, scorer, args,
                       hyp_dir="data/conll14/hyp"):
    """Score every hypothesis sentence on its own with the ERRANT-based scorer.

    Args:
        m2_file: Gold annotation path, passed as the second argument to
            ``errant_load_annotation``.
        hyp_file: Hypothesis annotation path, passed as the first argument to
            ``errant_load_annotation``, which yields the source sentences,
            the gold edits and the system edits.
        hyp_n: File name of the plain-text system output inside ``hyp_dir``;
            it is read one hypothesis per line, whitespace stripped.
        references: Per-sentence references, iterated in lockstep with the
            hypotheses and forwarded to the scoring routine.
        scorer: Scorer object forwarded unchanged to
            ``batch_multi_pre_rec_f1_sent_errant``.
        args: Options object forwarded unchanged to the scoring routine.
        hyp_dir: Directory that contains ``hyp_n``. Defaults to the location
            that was previously hard-coded.

    Returns:
        A ``(mean_score, score_lst)`` tuple: the sum of the per-sentence
        scores divided by the number of hypothesis lines, and the list of
        per-sentence scores.

    Raises:
        ZeroDivisionError: If the system output file contains no lines.
    """
    source_sentences, gold_edits, sys_edits = errant_load_annotation(hyp_file, m2_file)
    sys_file = f"{hyp_dir}/{hyp_n}"
    # Context manager guarantees the handle is closed even if reading fails.
    with smart_open(sys_file, 'r') as fin:
        system_sentences = [line.strip() for line in fin]

    # NOTE: zip() stops at the shortest input, so a length mismatch silently
    # drops the tail; the mean below is still divided by the number of
    # hypothesis lines.
    aligned = zip(system_sentences, source_sentences, references, sys_edits, gold_edits)
    score_lst = [
        # ``sys_edit`` rather than ``sys`` so the imported ``sys`` module is
        # not shadowed. Only the last element of the result is kept.
        batch_multi_pre_rec_f1_sent_errant(
            [hyp], [src], [sys_edit], [golds], [refs], scorer, args
        )[-1]
        for hyp, src, refs, sys_edit, golds in tqdm(aligned)
    ]

    return sum(np.array(score_lst)) / len(system_sentences), score_lst


def compute_errant(m2_file, hyp_file, hyp_n, references, scorer, args,
                   hyp_dir="data/conll14/hyp"):
    """Score the whole system output in one batch with the ERRANT-based scorer.

    Args:
        m2_file: Gold annotation path, passed as the second argument to
            ``errant_load_annotation``.
        hyp_file: Hypothesis annotation path, passed as the first argument to
            ``errant_load_annotation``, which yields the source sentences,
            the gold edits and the system edits.
        hyp_n: File name of the plain-text system output inside ``hyp_dir``;
            it is read one hypothesis per line, whitespace stripped.
        references: References forwarded unchanged to
            ``batch_multi_pre_rec_f1_errant``.
        scorer: Scorer object forwarded unchanged to
            ``batch_multi_pre_rec_f1_errant``.
        args: Options object forwarded unchanged to the scoring routine.
        hyp_dir: Directory that contains ``hyp_n``. Defaults to the location
            that was previously hard-coded.

    Returns:
        A ``(score, None)`` tuple. ``score`` is the last element of the
        result of ``batch_multi_pre_rec_f1_errant``; the second slot is
        always ``None``, keeping the two-element return shape used by
        ``compute_senterrant``.
    """
    source_sentences, gold_edits, sys_edits = errant_load_annotation(hyp_file, m2_file)
    sys_file = f"{hyp_dir}/{hyp_n}"
    # Context manager guarantees the handle is closed even if reading fails.
    with smart_open(sys_file, 'r') as fin:
        system_sentences = [line.strip() for line in fin]

    results = batch_multi_pre_rec_f1_errant(
        system_sentences, source_sentences, sys_edits, gold_edits,
        references, scorer, args,
    )
    return results[-1], None


def get_plm_scorer(args, references=None):
    """Build the pretrained-language-model scorer selected by ``args.scorer``.

    Args:
        args: Options object. ``args.scorer`` must be ``"bertscore"`` or
            ``"bartscore"``; ``args.device`` and ``args.model_type`` are
            forwarded to the chosen scorer's constructor.
        references: Sentences used as ``idf_sents`` for BERTScore. Must be
            non-empty when ``args.scorer`` is ``"bertscore"``; ignored for
            ``"bartscore"``.

    Returns:
        A ``BERTScorer`` or ``BARTScorer`` instance.

    Raises:
        AssertionError: If ``args.scorer`` is not one of the two supported
            names, or if BERTScore is requested without references.
    """
    kind = args.scorer
    assert kind in ["bertscore", "bartscore"]
    if kind == "bartscore":
        # The BART checkpoint name is the model type under the "facebook/" hub
        # namespace.
        plm = BARTScorer(device=args.device, checkpoint=f"facebook/{args.model_type}")
    elif kind == "bertscore":
        # The references are passed as ``idf_sents`` below, so they must be
        # present.
        assert references
        bert_options = dict(
            device=args.device,
            model_type=args.model_type,
            lang="en",
            rescale_with_baseline=True,
            idf=True,
            idf_sents=references,
        )
        plm = BERTScorer(**bert_options)
    return plm