import argparse
import sys
import numpy as np 
import torch
import pickle
import scipy.special

sys.path.append("xxx")

from tqdm import tqdm
from rouge_score import rouge_scorer
from scipy.stats import pearsonr
from scipy.stats import ttest_ind

from common.evaluation import *



# 0
def metrics_correlation(scored_summaries, args):
    """Print the Pearson correlation between every pair of scoring metrics.

    For each metric ``m`` (one per entry of ``args.scoring_methods``), the score
    of the first candidate of every example, ``scored_summaries[i][1][m][0]``,
    is gathered into a vector; every unordered pair of these vectors is then
    correlated and printed.
    """
    n_metrics = len(args.scoring_methods)
    first_candidate_scores = [
        np.array([example[1][m][0] for example in scored_summaries])
        for m in range(n_metrics)
    ]
    for a in range(n_metrics):
        for b in range(a + 1, n_metrics):
            corr_ab, _ = pearsonr(first_candidate_scores[a], first_candidate_scores[b])
            print("Corr between metric {} and {}: {:.4f}".format(a, b, corr_ab))


# 3
def qualitative_samples(val_texts, scored_summaries, val_preds_idx, val_overall_predictions, val_labels, args):
    """Print qualitative examples where reranking changed the summary the most.

    A random subsample of at most 500 examples is drawn. For each sampled
    example, two candidates are scored against the label:
      * the first candidate, ``scored_summaries[i][0][0]`` (the "top beam");
      * the reranker's choice, ``scored_summaries[i][0][val_preds_idx[i]]``.
    Each score is 100 * the mean of the three values returned by
    ``get_rouge_scores``; the scorer is configured with rouge1 / rouge2 /
    rougeLsum.

    The mean gain (reranked - top) is printed. Then the (up to) 10 examples
    with the largest gain are printed, from the 10th-largest to the largest.

    If ``args.evaluate_qualitative_full_predictions`` is set, every candidate
    of those examples is listed with its model score, model rank, ROUGE scores
    and ROUGE rank (rank 1 = highest). Otherwise the top-beam and reranked
    summaries are printed, but only for examples whose top-beam mean ROUGE is
    >= 35.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=args.stemmer)

    # sample
    idx = np.random.permutation(len(val_texts))[:500]
    sample_val_texts = [val_texts[i] for i in idx]
    sample_scored_summaries = [scored_summaries[i] for i in idx]
    sample_val_preds_idx = [val_preds_idx[i] for i in idx]
    sample_val_labels = [val_labels[i] for i in idx]
    sample_val_overall_predictions = [val_overall_predictions[i] for i in idx]
    print(len(sample_val_texts), len(sample_scored_summaries), len(sample_val_preds_idx), len(sample_val_labels), len(sample_val_overall_predictions))

    # score top beam baseline + reranked summary
    top_summaries = []
    top_mean_rouges = []
    pred_summaries = []
    pred_mean_rouges = []
    for i in tqdm(range(len(sample_val_texts))):
        label = sample_val_labels[i]
        top_summary = pre_rouge_processing(sample_scored_summaries[i][0][0], args)
        top_summaries.append(top_summary)
        pred_summary = pre_rouge_processing(sample_scored_summaries[i][0][sample_val_preds_idx[i]], args)
        pred_summaries.append(pred_summary)
        top_r1, top_r2, top_rl = get_rouge_scores(top_summary, label, scorer, args)
        top_mean_rouges.append(100 * (top_r1 + top_r2 + top_rl) / 3)
        pred_r1, pred_r2, pred_rl = get_rouge_scores(pred_summary, label, scorer, args)
        pred_mean_rouges.append(100 * (pred_r1 + pred_r2 + pred_rl) / 3)
    top_mean_rouges = np.array(top_mean_rouges)
    pred_mean_rouges = np.array(pred_mean_rouges)

    gains = pred_mean_rouges - top_mean_rouges
    print("Mean mean ROUGE improvement: {:.4f}".format(np.mean(gains)))
    sort_idx = np.argsort(gains)[::-1]
    # sort_idx lists the largest gains first. Walk its first n_show entries
    # backwards so the biggest improvement is printed last. n_show guards
    # against samples with fewer than 10 examples.
    n_show = min(10, len(sort_idx))
    for rank in reversed(range(n_show)):
        s = sort_idx[rank]
        print("\n", "*"*50, "SOURCE:")
        print(sample_val_texts[s])
        print("*"*20, "LABEL:")
        print(sample_val_labels[s].replace("\n", " "))
        if args.evaluate_qualitative_full_predictions:
            print("*"*20, "CANDIDATES:")
            # Use this example's own candidate count, not that of example 0.
            candidates = sample_scored_summaries[s][0]
            n_cands = len(candidates)
            # argsort().argsort() gives each entry's ascending position;
            # n_cands - position turns it into a 1-based rank with 1 = highest.
            model_ranks = np.array(sample_val_overall_predictions[s]).argsort().argsort()
            summaries_s = []
            mean_rs = []
            r1s = []
            r2s = []
            rls = []
            for candidate in candidates:
                summary_j = pre_rouge_processing(candidate, args)
                summaries_s.append(summary_j)
                r1_j, r2_j, rl_j = get_rouge_scores(summary_j, sample_val_labels[s], scorer, args)
                r1_j, r2_j, rl_j = 100 * r1_j, 100 * r2_j, 100 * rl_j
                mean_rs.append((r1_j + r2_j + rl_j) / 3)
                r1s.append(r1_j)
                r2s.append(r2_j)
                rls.append(rl_j)
            label_ranks = np.array(mean_rs).argsort().argsort()
            for j in range(n_cands):
                print("\nCANDIDATE {} / {}, MODEL SCORE: {:.4f}, MODEL RANK: {} // mean R: {:.4f},  R-1: {:.4f},  R-2: {:.4f}, R-L: {:.4f}, mean R rank: {}".format(
                    j+1, n_cands, sample_val_overall_predictions[s][j], n_cands - model_ranks[j],
                    mean_rs[j], r1s[j], r2s[j], rls[j], n_cands - label_ranks[j]
                ))
                print(summaries_s[j].replace("\n", " "))
        elif top_mean_rouges[s] >= 35:
            # Hard-coded display filter: only show examples whose top beam is
            # already reasonably good (mean ROUGE >= 35).
            print("*"*20, "TOP SUMMARY (mean ROUGE: {:.4f}):".format(top_mean_rouges[s]))
            print(top_summaries[s].replace("\n", " "))
            print("*"*20, "RERANKED SUMMARY (mean ROUGE: {:.4f}):".format(pred_mean_rouges[s]))
            print(pred_summaries[s].replace("\n", " "))


# 4
def reranker_on_label(val_dataset, val_labels, model, args):
    """Measure how the reranker ranks the reference summary among the candidates.

    The label ``val_labels[i]`` is appended as the last candidate of every
    example in ``val_dataset.scored_summaries``, and a score of 100 is appended
    for every metric. This mutates ``val_dataset`` in place.

    The model is then run over the dataset under ``torch.no_grad()`` and two
    numbers are printed:
      * label accuracy: % of examples whose predicted index
        (``output["total_predictions_idx"]``) is the label's (last) position;
      * label mean rank: mean 1-based rank of the label under
        ``output["overall_predictions"]`` (1 = highest prediction).

    NOTE(review): calling this twice on the same dataset appends the label twice.
    """
    # include the label as a candidate (in place)
    for i in tqdm(range(len(val_dataset.scored_summaries))):
        val_dataset.scored_summaries[i][0].append(val_labels[i])
        for metric_scores in val_dataset.scored_summaries[i][1]:
            metric_scores.append(100)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.inference_bs, shuffle=False)

    n_candidates = len(val_dataset.scored_summaries[0][0])
    print("Number of summary candidates: {}".format(n_candidates))

    # inference
    device = model.pretrained_model.device
    val_preds_idx = []
    val_overall_predictions = []
    for batch in tqdm(val_loader):
        model.zero_grad()
        text_ids = batch["text_input_ids"].to(device)
        text_mask = batch["text_attn_mask"].to(device)
        text_and_summaries_ids = batch["text_and_summaries_input_ids"].to(device)
        text_and_summaries_mask = batch["text_and_summaries_attn_mask"].to(device)
        with torch.no_grad():
            output = model(batch["mode"], text_ids, text_mask, text_and_summaries_ids, text_and_summaries_mask, None, None, batch["scores"])
        val_preds_idx += output["total_predictions_idx"]
        val_overall_predictions += output["overall_predictions"]
    val_preds_idx = np.array(val_preds_idx)

    # label accuracy: the label was appended last, so its index is n_candidates - 1
    label_acc = 100 * np.mean(val_preds_idx == (n_candidates - 1))
    print("Label accuracy (% that the label is ranked at the top): {:.4f}".format(label_acc))

    # label rank: argsort().argsort() gives ascending positions; n - position
    # makes rank 1 the highest prediction. The label is the last entry.
    all_ranks = [
        n_candidates - np.array(preds).argsort().argsort()[-1]
        for preds in val_overall_predictions
    ]
    m_rank = np.mean(all_ranks)
    print("Label mean rank: {:.4f}".format(m_rank))


# 5
def ranking_correlations(scored_summaries, val_overall_predictions, args):
    """Print Pearson correlations of beam order and reranker order with the true order.

    All ranks are 1-based with 1 = best:
      * the "true" ranks come from the first scoring metric's scores,
        ``scored_summaries[i][1][0]``;
      * a candidate's beam rank is its position in the candidate list.
    Ranks from all examples are pooled before each correlation is computed.
    ``args`` is unused.
    """
    n_candidates = len(val_overall_predictions[0])

    def descending_ranks(values):
        # The double argsort gives each entry's ascending position; flip it so
        # the largest value gets rank 1.
        return n_candidates - np.array(values).argsort().argsort()

    beam_ranks, reranker_ranks, true_ranks = [], [], []
    for i, reranker_preds in enumerate(val_overall_predictions):
        beam_ranks.extend(1 + np.arange(n_candidates))
        reranker_ranks.extend(descending_ranks(reranker_preds))
        true_ranks.extend(descending_ranks(scored_summaries[i][1][0]))

    true_ranks = np.array(true_ranks)
    beam_corr, _ = pearsonr(np.array(beam_ranks), true_ranks)
    print("Pearson corr coef between beam ranks and true ranks: {:.4f}".format(beam_corr))
    reranker_corr, _ = pearsonr(np.array(reranker_ranks), true_ranks)
    print("Pearson corr coef between reranker ranks and true ranks: {:.4f}".format(reranker_corr))


# 6
def length_evaluation(val_texts, scored_summaries, val_preds_idx, val_labels, args, n_bins=5):
    """Break down base / oracle / reranked scores by reference-summary length.

    Examples are split into ``n_bins`` quantile bins of label length, measured
    in whitespace tokens. Each bin is ``[low, high)``, except the last, which is
    ``[low, high]`` so that the longest labels are included. Empty bins (which
    occur when percentiles coincide) are reported and skipped.

    For each non-empty bin, prints:
      * base performance: per-metric mean of the first candidate's score;
      * oracle performance: per-metric mean of each example's best candidate
        score;
      * reranking performance: as printed/returned by ``overall_eval`` on the
        reranked candidates (``scored_summaries[i][0][val_preds_idx[i]]``).
    A one-line summary per bin with absolute and relative gains follows.
    """
    all_lengths = np.array([len(label.split()) for label in val_labels])
    all_base_metrics = []
    all_oracle_metrics = []
    all_reranked_metrics = []
    all_lows = []
    all_highs = []
    for i in range(n_bins):
        low_thresh = np.percentile(all_lengths, 100 * i / n_bins)
        high_thresh = np.percentile(all_lengths, 100 * (i + 1) / n_bins)
        in_bin = all_lengths >= low_thresh
        if i == n_bins - 1:
            # Close the last bin on the right; with a strict '<' the labels of
            # maximal length (== 100th percentile) would never be counted.
            in_bin &= all_lengths <= high_thresh
        else:
            in_bin &= all_lengths < high_thresh
        idx = np.flatnonzero(in_bin)
        print("\n", "*"*70, "Summaries of length between {} and {}, # summaries: {}".format(low_thresh, high_thresh, len(idx)))
        if len(idx) == 0:
            print("Empty length bin, skipping.")
            continue
        val_texts_i = [val_texts[j] for j in idx]
        scored_summaries_i = [scored_summaries[j] for j in idx]
        val_labels_i = [val_labels[j] for j in idx]
        val_preds_i = [scored_summaries[j][0][val_preds_idx[j]] for j in idx]
        n_metrics = len(scored_summaries_i[0][1])
        # base performance: score of the first candidate, averaged per metric
        mean_metrics = [np.mean([ex[1][m][0] for ex in scored_summaries_i]) for m in range(n_metrics)]
        print("\nBase performance:")
        print(mean_metrics)
        all_base_metrics.append(mean_metrics)
        # oracle performance: best candidate score per example, averaged per metric
        oracle_metrics = [np.mean([max(ex[1][m]) for ex in scored_summaries_i]) for m in range(n_metrics)]
        print("\nOracle performance:")
        print(oracle_metrics)
        all_oracle_metrics.append(oracle_metrics)
        # reranking performance
        print("\nReranking performance:")
        r1, r2, rl = overall_eval(val_texts_i, val_preds_i, val_labels_i, args)
        all_reranked_metrics.append([np.mean(r1), np.mean(r2), np.mean(rl)])
        all_lows.append(int(low_thresh))
        all_highs.append(int(high_thresh))
    for i in range(len(all_lows)):
        mean_base = np.mean(all_base_metrics[i])
        mean_oracle = np.mean(all_oracle_metrics[i])
        oracle_diff = 100 * (mean_oracle - mean_base) / mean_base
        mean_reranked = np.mean(all_reranked_metrics[i])
        mean_diff = mean_reranked - mean_base
        # Share (%) of the base->oracle headroom recovered by reranking.
        relative_gain = 100 * mean_diff / (mean_oracle - mean_base)
        print("Summaries between length {} and {}, base mean ROUGE: {:.4f} // oracle mean ROUGE: {:.4f} (max gain: {:.4f} or {:.4f}%) // reranked mean ROUGE: {:.4f} // absolute gain: {:.4f}, relative gain (%): {:.4f}".format(
            all_lows[i], all_highs[i], mean_base, mean_oracle, mean_oracle - mean_base, oracle_diff, mean_reranked, mean_diff, relative_gain))


# 7
def recall_evaluation(scored_summaries, val_overall_predictions, args):
    """Print recall@t (t = 1..5) of the best candidate for base models and the reranker.

    The "best" candidate of an example maximizes the sum of all
    ``args.scoring_methods`` scores. Candidates are indexed as consecutive
    blocks of ``args.num_beams`` per generation method (block ``k`` covers
    ``[k * num_beams, (k + 1) * num_beams)``). Within a block, the base
    model's ranking is the candidate order.

    * Naive recall@t: the arg-max best candidate is ranked within the top t.
      Baseline: t / n_candidates.
    * Real recall@t: some candidate in the top t reaches the best summed score,
      so ties count. Baseline: the probability that a uniformly random ranking
      of all candidates puts at least one tied-best candidate in the top t.
    """
    threshs = [1, 2, 3, 4, 5]
    n_candidates = len(val_overall_predictions[0])
    n_methods = len(args.generation_methods)
    num_beams = args.num_beams
    print("# Candidates: {}".format(n_candidates))

    def summed_scores(i):
        # Per-candidate sum of all scoring metrics for example i.
        total = np.zeros(len(val_overall_predictions[i]))
        for j in range(len(args.scoring_methods)):
            total += np.array(scored_summaries[i][1][j])
        return total

    def method_slice(k):
        # Candidate positions produced by generation method k.
        return slice(k * num_beams, (k + 1) * num_beams)

    # naive recall
    print("\nNaive recall:")
    base_recalls = [[[] for _ in threshs] for _ in range(n_methods)]
    reranked_recalls = [[] for _ in threshs]
    for i in tqdm(range(len(val_overall_predictions))):
        labels_i = summed_scores(i)
        # base: 1-based position of the best candidate in each method's beam order
        for k in range(n_methods):
            best_pos = np.argmax(labels_i[method_slice(k)]) + 1
            for j, t in enumerate(threshs):
                base_recalls[k][j].append(int(best_pos <= t))
        # reranked: reranker rank (1 = highest prediction) of the best candidate
        preds_i = np.array(val_overall_predictions[i])
        ranks_i = len(preds_i) - preds_i.argsort().argsort()
        best_idx = np.argmax(labels_i)
        for j, t in enumerate(threshs):
            reranked_recalls[j].append(int(ranks_i[best_idx] <= t))
    for j, t in enumerate(threshs):
        recall_bases = ["{:.4f}".format(100 * np.mean(base_recalls[k][j])) for k in range(n_methods)]
        recall_reranker = 100 * np.mean(reranked_recalls[j])
        print("Recall at {}: Baseline: {:.2f} // Base model(s): {} // Reranker model: {:.2f}".format(t, 100 * t / n_candidates, recall_bases, recall_reranker))

    # real recall
    print("\nReal recall:")
    baseline_recalls = [[] for _ in threshs]
    base_recalls = [[[] for _ in threshs] for _ in range(n_methods)]
    reranked_recalls = [[] for _ in threshs]
    for i in tqdm(range(len(val_overall_predictions))):
        labels_i = summed_scores(i)
        # base
        for k in range(n_methods):
            labels_i_k = labels_i[method_slice(k)]
            best_val = np.max(labels_i_k)
            for j, t in enumerate(threshs):
                base_recalls[k][j].append(int(np.max(labels_i_k[:t]) == best_val))
        # reranked
        preds_i = np.array(val_overall_predictions[i])
        top_preds_idx = np.argsort(preds_i)[::-1]
        best_val = np.max(labels_i)
        labels_by_pred = labels_i[top_preds_idx]
        for j, t in enumerate(threshs):
            reranked_recalls[j].append(int(np.max(labels_by_pred[:t]) == best_val))
        # baseline: P(random ranking puts >= 1 of the n_top tied-best in the top t)
        #   = 1 - C(N - t, n_top) / C(N, n_top)
        # N must be the number of candidates the ties were counted over (all of
        # them), not args.num_beams: with several generation methods n_top can
        # exceed num_beams, making C(num_beams, n_top) == 0. N - t is clamped at
        # 0 (C(0, n_top) == 0 -> recall 1) because binom extends to negative N.
        n_total = len(preds_i)
        n_top = np.sum(labels_i == best_val)
        all_draws = scipy.special.binom(n_total, n_top)
        for j, t in enumerate(threshs):
            missing_draws = scipy.special.binom(max(n_total - t, 0), n_top)
            baseline_recalls[j].append((all_draws - missing_draws) / all_draws)
    for j, t in enumerate(threshs):
        recall_baseline = 100 * np.mean(baseline_recalls[j])
        recall_bases = ["{:.4f}".format(100 * np.mean(base_recalls[k][j])) for k in range(n_methods)]
        recall_reranker = 100 * np.mean(reranked_recalls[j])
        print("Recall at {}: Baseline: {:.2f} // Base model(s): {} // Reranker model: {:.2f}".format(t, recall_baseline, recall_bases, recall_reranker))


# 8 
def ttest(scored_summaries, val_overall_predictions, args):
    """Per scoring metric, compare base-model scores with reranked scores.

    For metric ``j``:
      * the reranked score of example ``i`` is the score of the candidate with
        the highest reranker prediction;
      * the base score for generation method ``k`` is the score of candidate
        ``k * args.num_beams``.
    Scores of ``"4_bartscore"`` are divided by 30 before averaging and testing.
    The factor is kept as-is; its meaning is not visible here. Prints the mean
    base scores, the mean reranked score and ``ttest_ind`` p-values.

    NOTE(review): base and reranked scores come from the same examples, so a
    paired test (``scipy.stats.ttest_rel``) may be more appropriate. It is kept
    as ``ttest_ind`` to preserve previously reported values.
    """
    n_examples = len(val_overall_predictions)
    # The reranker's chosen candidate does not depend on the metric: compute once.
    reranked_idx = [np.argmax(np.array(val_overall_predictions[i])) for i in range(n_examples)]
    for j, metric in enumerate(args.scoring_methods):
        reranked_scores = np.array([scored_summaries[i][1][j][reranked_idx[i]] for i in range(n_examples)])
        if metric == "4_bartscore":
            # Out-of-place division: an in-place `/=` raises on integer arrays.
            reranked_scores = reranked_scores / 30
        reranked_score = np.mean(reranked_scores)
        all_base_scores = []
        all_p_values = []
        for k in range(len(args.generation_methods)):
            base_scores_k = np.array([scored_summaries[i][1][j][k * args.num_beams] for i in range(n_examples)])
            if metric == "4_bartscore":
                base_scores_k = base_scores_k / 30
            all_base_scores.append(np.mean(base_scores_k))
            _, p_value = ttest_ind(base_scores_k, reranked_scores)
            all_p_values.append(p_value)
        all_base_scores = ["{:.4f}".format(x) for x in all_base_scores]
        all_p_values = ["{:.8f}".format(x) for x in all_p_values]
        print("Metric {}, base scores: {}, reranked score: {:.4f}, T-test p-values: {}".format(
            metric, all_base_scores, reranked_score, all_p_values))


# 8 
def _print_ttest_report(metric_name, base_scores, reranked_scores, scale=1):
    """Print per-method mean base scores, the mean reranked score and t-test p-values.

    base_scores: one sequence of per-example scores per generation method.
    reranked_scores: per-example scores of the reranked summaries.
    scale: multiplier applied to the displayed means only (p-values use raw scores).
    """
    reranked_scores = np.array(reranked_scores)
    means = []
    p_values = []
    for base_k in base_scores:
        means.append(scale * np.mean(base_k))
        _, p_value = ttest_ind(np.array(base_k), reranked_scores)
        p_values.append(p_value)
    means = ["{:.4f}".format(x) for x in means]
    p_values = ["{:.8f}".format(x) for x in p_values]
    print("{}, base scores: {}, reranked score: {:.4f}, T-test p-values: {}".format(
        metric_name, means, scale * np.mean(reranked_scores), p_values))


def ttest_all_metrics(scored_summaries, val_overall_predictions, val_labels, args):
    """Recompute ROUGE / BERTScore / BARTScore for base and reranked summaries and t-test them.

    The reranked summary of example ``i`` is the candidate with the highest
    reranker prediction. The base summary for generation method ``k`` is
    candidate ``k * args.num_beams``. For each metric, the mean base score per
    method, the mean reranked score and ``ttest_ind`` p-values are printed:
      * ROUGE values are multiplied by 100 per example;
      * BERTScore means are displayed x100.

    NOTE(review): the samples are paired (same examples); ``ttest_rel`` may be
    more appropriate than ``ttest_ind``.
    """
    n_examples = len(val_overall_predictions)
    n_methods = len(args.generation_methods)
    reranked_summaries = [scored_summaries[i][0][np.argmax(np.array(val_overall_predictions[i]))] for i in range(n_examples)]
    all_base_summaries = [
        [scored_summaries[i][0][k * args.num_beams] for i in range(n_examples)]
        for k in range(n_methods)
    ]

    # ROUGE
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=args.stemmer)
    reranked_r1s, reranked_r2s, reranked_rls = [], [], []
    base_r1s = [[] for _ in range(n_methods)]
    base_r2s = [[] for _ in range(n_methods)]
    base_rls = [[] for _ in range(n_methods)]
    for i, reranked_summary in enumerate(reranked_summaries):
        label = val_labels[i]
        # reranked
        r1, r2, rl = get_rouge_scores(pre_rouge_processing(reranked_summary, args), label, scorer, args)
        reranked_r1s.append(100 * r1)
        reranked_r2s.append(100 * r2)
        reranked_rls.append(100 * rl)
        # base
        for k in range(n_methods):
            r1, r2, rl = get_rouge_scores(pre_rouge_processing(all_base_summaries[k][i], args), label, scorer, args)
            base_r1s[k].append(100 * r1)
            base_r2s[k].append(100 * r2)
            base_rls[k].append(100 * rl)
    _print_ttest_report("R-1", base_r1s, reranked_r1s)
    _print_ttest_report("R-2", base_r2s, reranked_r2s)
    _print_ttest_report("R-L", base_rls, reranked_rls)

    # BERTScore (bertscore_eval output presumably a torch tensor, given .numpy())
    reranked_bs = bertscore_eval(reranked_summaries, val_labels, args, verbose=False).numpy()
    base_bs = [bertscore_eval(all_base_summaries[k], val_labels, args, verbose=False).numpy() for k in range(n_methods)]
    _print_ttest_report("BS", base_bs, reranked_bs, scale=100)

    # BARTScore
    reranked_bas = bartscore_eval(reranked_summaries, val_labels, args)
    base_bas = [bartscore_eval(all_base_summaries[k], val_labels, args) for k in range(n_methods)]
    _print_ttest_report("BaS", base_bas, reranked_bas)


# 9
def summary_length(scored_summaries, val_preds_idx, val_labels, args):
    """Print the mean and (population) std of word counts for three kinds of summary.

    The three groups are labels, base summaries (candidate 0 of each example)
    and reranked summaries (candidate ``val_preds_idx[i]``). Words are
    whitespace-separated tokens. ``args`` is unused.
    """
    def word_counts(texts):
        return np.array([len(text.split()) for text in texts])

    n_examples = len(val_preds_idx)
    label_lengths = word_counts(val_labels)
    base_lengths = word_counts(scored_summaries[i][0][0] for i in range(n_examples))
    reranked_lengths = word_counts(scored_summaries[i][0][val_preds_idx[i]] for i in range(n_examples))
    print("Label summaries mean length: {:.4f}, std: {:.4f}".format(np.mean(label_lengths), np.std(label_lengths)))
    print("Base summaries mean length: {:.4f}, std: {:.4f}".format(np.mean(base_lengths), np.std(base_lengths)))
    print("Reranked summaries mean length: {:.4f}, std: {:.4f}".format(np.mean(reranked_lengths), np.std(reranked_lengths)))


# 10
def new_ngrams(val_texts, scored_summaries, val_preds_idx, val_labels, args):
    """Print the % of summary n-grams (n = 1..4) that are absent from the source.

    Only the first (up to) 500 examples are used. Text is lower-cased and split
    on whitespace. ``"<n>"`` markers are stripped from candidate summaries, but
    not from labels or sources.

    For each example, the fraction of novel n-grams is computed for:
      * the label;
      * the reranked candidate;
      * for each generation method ``m``, candidate ``m * args.num_beams``.
    Fractions are averaged over examples. A text with no n-grams of a given
    order contributes 0.
    """
    max_ngram = 4
    n_methods = len(args.generation_methods)

    def ngrams(words, n):
        # A sequence of L words has L - n + 1 contiguous n-grams (none if L < n).
        return [tuple(words[k:k + n]) for k in range(len(words) - n + 1)]

    def novel_fraction(words, src_ngrams, n):
        grams = ngrams(words, n)
        if not grams:
            return 0
        return sum(gram not in src_ngrams for gram in grams) / len(grams)

    all_label_new_ngrams = [[] for _ in range(max_ngram)]
    all_reranked_new_ngrams = [[] for _ in range(max_ngram)]
    all_base_new_ngrams = [[[] for _ in range(n_methods)] for _ in range(max_ngram)]
    for i in tqdm(range(min(len(val_texts), 500))):
        src_words = val_texts[i].lower().split()
        label_words = val_labels[i].lower().split()
        reranked_words = scored_summaries[i][0][val_preds_idx[i]].lower().replace("<n>", " ").split()
        base_words = [
            scored_summaries[i][0][m * args.num_beams].lower().replace("<n>", " ").split()
            for m in range(n_methods)
        ]
        for n in range(1, max_ngram + 1):
            # Built once per (example, n) as a set for O(1) membership tests.
            src_ngrams = set(ngrams(src_words, n))
            all_label_new_ngrams[n - 1].append(novel_fraction(label_words, src_ngrams, n))
            all_reranked_new_ngrams[n - 1].append(novel_fraction(reranked_words, src_ngrams, n))
            for m, words in enumerate(base_words):
                all_base_new_ngrams[n - 1][m].append(novel_fraction(words, src_ngrams, n))

    for j in range(1, max_ngram + 1):
        base_ngrams = ["{:.8f}".format(100 * np.mean(all_base_new_ngrams[j - 1][m])) for m in range(n_methods)]
        print("For {}-grams, % new in base: {}, % new in reranked summaries: {:.4f}, % new in labels: {:.4f}, ".format(
            j, base_ngrams, 100 * np.mean(all_reranked_new_ngrams[j - 1]), 100 * np.mean(all_label_new_ngrams[j - 1])))


def export(val_texts, scored_summaries, val_overall_predictions, args):
    """Pickle a random sample of sources, base summaries and reranked summaries.

    Draws up to ``args.n_to_export`` random examples without replacement and
    prints the first (up to) 10 of them. It then pickles
    ``{"src": [...], "baseline": [...], "SummaReranker": [...]}`` to
    ``args.export_name``. Newlines are replaced by spaces.

    The baseline is candidate 0. The reranked summary is the candidate with the
    highest reranker prediction.
    """
    p = np.random.permutation(len(val_overall_predictions))[:args.n_to_export]
    sampled_texts = [val_texts[idx].replace("\n", " ") for idx in p]
    sampled_base_summaries = [scored_summaries[idx][0][0].replace("\n", " ") for idx in p]
    sampled_reranked_summaries = [
        scored_summaries[idx][0][np.argmax(np.array(val_overall_predictions[idx]))].replace("\n", " ")
        for idx in p
    ]
    # Preview at most 10 examples; fewer may have been sampled.
    for i in range(min(10, len(p))):
        print(sampled_texts[i])
        print("*"*30)
        print(sampled_base_summaries[i])
        print("*"*30)
        print(sampled_reranked_summaries[i])
    d = {
        "src": sampled_texts,
        "baseline": sampled_base_summaries,
        "SummaReranker": sampled_reranked_summaries,
    }
    print(len(d["src"]), len(d["baseline"]))
    with open(args.export_name, "wb") as f:
        pickle.dump(d, f)

