from __future__ import print_function, unicode_literals, division

import numpy as np

from nltk.tokenize import sent_tokenize
from rouge_score import rouge_scorer
from bert_score import score
from common.bart_score import BARTScorer
from scipy.stats import pearsonr
from common.summary_processing import pre_rouge_processing



def map_language(args):
    """Translate the first entry of ``args.languages`` into a language name.

    Recognised codes are "fr", "de", "es", "ru" and "trk"; any other code
    yields ``None``.
    """
    names_by_code = {
        "fr": "french",
        "de": "german",
        "es": "spanish",
        "ru": "russian",
        "trk": "turkish",
    }
    return names_by_code.get(args.languages[0])


def overall_eval(val_texts, val_summaries, val_labels, args):
    """Run every evaluation that is switched on in ``args``.

    Args:
        val_texts: source documents, indexed in parallel with the summaries.
        val_summaries: predicted summaries.
        val_labels: ground-truth summaries.
        args: options object; the flags read here are ``eval_rouge``,
            ``eval_bertscore``, ``eval_bartscore``, ``eval_ngram_overlap``,
            ``eval_new_ngrams``, ``eval_rouge_text`` and ``check_correlation``.

    Returns:
        The three values returned by ``rouge_eval`` against the true labels
        (R-1, R-2, R-L), or ``(-1, -1, -1)`` when ``args.eval_rouge`` is off.
    """
    # 1 - ROUGE
    r1_true = -1
    r2_true = -1
    rl_true = -1
    if args.eval_rouge:
        r1_true, r2_true, rl_true = rouge_eval("true labels", val_summaries, val_texts, val_labels, args, show_summaries = True)
    # 2 - BERTScore
    if args.eval_bertscore:
        bertscore_eval(val_summaries, val_labels, args)
    # 3 - BARTScore
    if args.eval_bartscore:
        bartscore_eval(val_summaries, val_labels, args)
    # 4 - Copying
    if args.eval_ngram_overlap:
        ngram_overlap_eval(val_summaries, val_texts, args)
    # 5 - Abstractiveness
    if args.eval_new_ngrams:
        new_ngram_eval(val_summaries, val_texts, args)
    # 6 - Overlap with source
    if args.eval_rouge_text:
        r1_text, r2_text, rl_text = rouge_eval("source", val_summaries, val_texts, val_texts, args)
        if args.check_correlation:
            if args.eval_rouge:
                r1_p = pearsonr(r1_true, r1_text)[0]
                r2_p = pearsonr(r2_true, r2_text)[0]
                rl_p = pearsonr(rl_true, rl_text)[0]
                print("Pearson correlations between ROUGE w true labels and ROUGE w source: {:.4f} / {:.4f} / {:.4f}".format(r1_p, r2_p, rl_p))
            else:
                # With eval_rouge off the *_true values are still the scalar -1
                # placeholders, which pearsonr cannot correlate with an array.
                print("Skipping Pearson correlations: ROUGE w true labels was not computed (eval_rouge is off)")

    return r1_true, r2_true, rl_true


################################################################ 1 - ROUGE


def rouge_eval(text, val_summaries, val_texts, val_labels, args, show_summaries = False):
    """Score each predicted summary with ROUGE-1 / ROUGE-2 / ROUGE-Lsum.

    ``text`` is only used to label the printed banner. When ``show_summaries``
    is true, the first ``args.n_show_summaries`` data points are printed
    (source text, its first three sentences, prediction and ground truth).
    Each prediction goes through ``pre_rouge_processing`` before scoring; the
    label is scored as given.

    Returns three numpy arrays with the per-sample scores multiplied by 100,
    in the order (R-1, R-2, R-L).
    """
    stars = "*" * 10
    print("\n", stars, "1 - ROUGE evaluation with {}".format(text), stars)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=args.stemmer)

    n_points = len(val_summaries)
    r1_values, r2_values, rl_values = [], [], []
    for idx in range(n_points):
        if show_summaries and idx < args.n_show_summaries:
            print("*" * 50)
            print("\nData point: {} / {}".format(idx + 1, n_points))
            print("\nText:")
            print(val_texts[idx].replace("\n", " "))
            print("\nLEAD-3:")
            first_three = " ".join(sent_tokenize(val_texts[idx])[:3])
            print(first_three.replace("\n", " "))
            print("\nPredicted summary:")
            print(val_summaries[idx].replace("\n", " "))
            print("\nGround-truth summary:")
            print(val_labels[idx])

        cleaned = pre_rouge_processing(val_summaries[idx], args)
        sample_r1, sample_r2, sample_rl = get_rouge_scores(cleaned, val_labels[idx], scorer, args)
        r1_values.append(sample_r1)
        r2_values.append(sample_r2)
        rl_values.append(sample_rl)

    # Rescale the per-sample scores to percentages.
    all_r1s = 100 * np.array(r1_values)
    all_r2s = 100 * np.array(r2_values)
    all_rls = 100 * np.array(rl_values)

    mean_r1, mean_r2, mean_rl = np.mean(all_r1s), np.mean(all_r2s), np.mean(all_rls)
    mean_r = (mean_r1 + mean_r2 + mean_rl) / 3
    # NOTE: the figures labelled "var" below are standard deviations (np.std).
    summary_line = "Mean R: {:.4f}, R-1: {:.4f} (var: {:.4f}), R-2: {:.4f} (var: {:.4f}), R-L: {:.4f} (var: {:.4f})"
    print(summary_line.format(
        mean_r, mean_r1, np.std(all_r1s), mean_r2, np.std(all_r2s), mean_rl, np.std(all_rls)))

    return all_r1s, all_r2s, all_rls


def get_rouge_scores(summary, label, scorer, args):
    """Return the (ROUGE-1, ROUGE-2, ROUGE-Lsum) F-measures for one summary.

    Args:
        summary: predicted summary, passed as the second argument to
            ``scorer.score``.
        label: reference summary, passed as the first argument to
            ``scorer.score``.
        scorer: object whose ``score(label, summary)`` returns a mapping with
            "rouge1", "rouge2" and "rougeLsum" entries exposing ``fmeasure``.
        args: options object; ``args.rouge_to_use`` selects the backend.

    Raises:
        ValueError: if ``args.rouge_to_use`` is not "rouge_score", the only
            backend implemented here.
    """
    # Fail loudly instead of falling through to a return of unbound locals.
    if args.rouge_to_use != "rouge_score":
        raise ValueError(
            "Unsupported rouge_to_use: {!r} (only 'rouge_score' is implemented)".format(args.rouge_to_use))

    rouge_scores = scorer.score(label, summary)
    r1 = rouge_scores["rouge1"].fmeasure
    r2 = rouge_scores["rouge2"].fmeasure
    rl = rouge_scores["rougeLsum"].fmeasure

    return r1, r2, rl


################################################################ 2 - BERTScore


def bertscore_eval(val_summaries, val_labels, args, verbose=True):
    """Compute BERTScore for the summaries, print the mean F1, return the F1s.

    ``val_summaries`` and ``val_labels`` are handed to ``score`` as its first
    and second arguments, with the language fixed to 'en'. ``args`` is
    accepted but not used. The printed mean is multiplied by 100; the
    returned F1 values are not rescaled.
    """
    banner = "*" * 10
    print("\n", banner, "2 - BERTScore evaluation", banner)
    _precision, _recall, f1 = score(val_summaries, val_labels, lang='en', verbose=verbose)
    print("Mean BERTScore F1: {:.2f}".format(100 * f1.mean()))
    return f1


################################################################ 3 - BARTScore


def bartscore_eval(val_summaries, val_labels, args):
    """Compute BARTScore for the summaries, print the mean, return the scores.

    A ``BARTScorer`` is built on ``args.device`` from the
    'facebook/bart-large-cnn' checkpoint and called as
    ``score(val_labels, val_summaries)``. The per-sample scores are returned
    as a numpy array.
    """
    banner = "*" * 10
    print("\n", banner, "3 - BARTScore evaluation", banner)
    scorer = BARTScorer(device = args.device, checkpoint = 'facebook/bart-large-cnn')
    per_sample = np.array(scorer.score(val_labels, val_summaries))
    print("Mean BARTScore: {:.2f}".format(np.mean(per_sample)))
    return per_sample


################################################################ 4 - COPYING


def ngram_overlap_eval(predicted_sents, src_sents, args):
    """Print how much of each prediction is copied from its source.

    For n = 1, 2, 3 the score of one prediction is the fraction of its n-gram
    positions whose n-gram also occurs in the source. Both sides are
    lower-cased and split on whitespace. A prediction with fewer than n words
    contributes nothing to the order-n average. The three averages are printed
    as percentages.

    Args:
        predicted_sents: predicted summaries (strings).
        src_sents: source texts, indexed in parallel with ``predicted_sents``.
        args: unused; kept so all evaluation functions share a call shape.
    """
    print("\n", "*"*10, "4 - Copying evalation / N-gram overlap", "*"*10)
    max_order = 3
    overlaps = [[] for _ in range(max_order)]
    for i in range(len(predicted_sents)):
        src_words = src_sents[i].lower().split()
        predicted_words = predicted_sents[i].lower().split()
        for n in range(1, max_order + 1):
            n_positions = len(predicted_words) - n + 1
            if n_positions <= 0:
                continue
            # A set of tuples gives constant-time membership tests instead of
            # scanning a list of lists for every predicted n-gram.
            src_ngrams = set(zip(*[src_words[k:] for k in range(n)]))
            predicted_ngrams = zip(*[predicted_words[k:] for k in range(n)])
            copied = sum(1 for gram in predicted_ngrams if gram in src_ngrams)
            overlaps[n - 1].append(copied / n_positions)
    m_uni, m_bi, m_tri = [100 * np.mean(np.array(ratios)) for ratios in overlaps]
    print("Unigram overlap: {:.2f}, bigram: {:.2f}, trigram: {:.2f}".format(m_uni, m_bi, m_tri))


################################################################ 5 - Abstractiveness


def new_ngram_eval(val_summaries, val_texts, args):
    """Print how abstractive the summaries are, as shares of novel n-grams.

    For n = 1, 2, 3, 4 the score of one summary is the fraction of its n-gram
    positions whose n-gram does not occur in the source text. Both sides are
    lower-cased and split on whitespace. A summary with fewer than n words
    contributes nothing to the order-n average. The four averages are printed
    as percentages.

    Args:
        val_summaries: predicted summaries (strings).
        val_texts: source texts, indexed in parallel with ``val_summaries``.
        args: unused; kept so all evaluation functions share a call shape.
    """
    print("\n", "*"*10, "5 - Abstractiveness / New n-gram", "*"*10)
    max_order = 4
    novel_ratios = [[] for _ in range(max_order)]
    for i in range(len(val_summaries)):
        src_words = val_texts[i].lower().split()
        predicted_words = val_summaries[i].lower().split()
        for n in range(1, max_order + 1):
            n_positions = len(predicted_words) - n + 1
            if n_positions <= 0:
                continue
            # A set of tuples gives constant-time membership tests instead of
            # scanning a list of lists for every predicted n-gram.
            src_ngrams = set(zip(*[src_words[k:] for k in range(n)]))
            predicted_ngrams = zip(*[predicted_words[k:] for k in range(n)])
            novel = sum(1 for gram in predicted_ngrams if gram not in src_ngrams)
            novel_ratios[n - 1].append(novel / n_positions)
    m_uni, m_bi, m_tri, m_quadri = [100 * np.mean(np.array(ratios)) for ratios in novel_ratios]
    print("New unigrams: {:.2f}, bigrams: {:.2f}, trigrams: {:.2f}, quadrigrams: {:.2f}".format(m_uni, m_bi, m_tri, m_quadri))
