# Score each summary candidate according to a specified metric.

import argparse
import pickle
import sys
import gc

sys.path.append("xxx")

from time import time 
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from rouge_score import rouge_scorer
from datasets import load_metric
from bert_score import score as bertscore_score

from common.utils import *
from common.bart_score import BARTScorer
from common.evaluation import overall_eval



def _str2bool(value):
    """Convert a command-line string such as "true"/"false" into a real bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value -- including "False" and "0" --
    parses as True.  This converter makes ``--flag False`` disable the flag.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false, yes/no, 1/0), got {!r}".format(value)
    )


parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--cuda', type=_str2bool, default=True)

# data
parser.add_argument('--summaries_path', type=str, default="../summaries/Reddit/2_diverse_beam_search/")
# 1_beam_search / 2_diverse_beam_search / 3_top_p_sampling / 4_top_k_sampling
parser.add_argument('--set', type=str, default="test")
# train / first_half_train / second_half_train / first_half_train_shuffled / second_half_train_shuffled / filtered_val / small_val / val / test
# Parsed as int so the value has the same type whether it comes from the
# default or from the command line (it used to be int by default, str from CLI).
parser.add_argument('--dataset_size', type=int, default=4222)
# CNNDM: 287113 / 143000 / 144113 / 143000 / 144113 / 13068 / 300 / 13368 / 11490
parser.add_argument('--size_to_score', type=int, default=4222)

# model
parser.add_argument('--model_name', type=str, default="pegasus_reddit_train_1_v2")
# in ["pegasus_cnndm_first_half", "pegasus_cnndm_first_half_shuffled_1", "pegasus_cnndm_second_half_shuffled_1", "pegasus_cnndm",
# "bart_cnndm_first_half_shuffled_1", "bart_cnndm_second_half_shuffled_1", "bart_cnndm"]
parser.add_argument('--num_candidates', type=int, default=15)

# METRIC
parser.add_argument('--label_metric', type=str, default="rouge_l")
# in ["mean_rouge", "rouge_1", "rouge_2", "rouge_l", "bertscore", "bleurt", "bartscore"]

# evaluation
parser.add_argument('--stemmer', type=_str2bool, default=True)
parser.add_argument('--n_show_summaries', type=int, default=0)
parser.add_argument('--highlights', type=_str2bool, default=False)
parser.add_argument('--clean_n', type=_str2bool, default=False)
parser.add_argument('--rouge_to_use', type=str, default="rouge_score")

# export
parser.add_argument('--save_scores', type=_str2bool, default=True)
parser.add_argument('--scored_summaries_path', type=str, default="../reranking_data/Reddit/2_diverse_beam_search/1c_rouge_l/")
# 1_mean_rouge / 1a_rouge_1 / 1b_rouge_2 / 1c_rouge_l / 2_bertscore / 3_bleurt / 4_bartscore

# metrics
parser.add_argument('--eval_top_candidate', type=_str2bool, default=False)
parser.add_argument('--eval_oracle', type=_str2bool, default=True)
# 1 - ROUGE
parser.add_argument('--eval_rouge', type=_str2bool, default=True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type=_str2bool, default=True)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type=_str2bool, default=True)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type=_str2bool, default=False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type=_str2bool, default=False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type=_str2bool, default=False)
# 0_stats
parser.add_argument('--check_correlation', type=_str2bool, default=False)

args = parser.parse_args()

# Echo the effective configuration so it ends up in the run log.
print("*"*50)
print(args)



def main(args):
    """Score every candidate summary against its reference label and report the results.

    Steps, in order:
      1. Call ``seed_everything(args.seed)`` and select the torch device
         (stored on ``args.device``).
      2. Load the pickled candidate summaries and labels for ``args.set`` and
         keep the first ``args.size_to_score`` examples.
      3. Score each candidate with ``args.label_metric``: a ROUGE variant,
         "bertscore", "bleurt" or "bartscore".
      4. Print the mean score of candidate 0 ("top beam") and of the best
         candidate per example ("oracle").
      5. Optionally pickle ``[candidates, scores]`` pairs (``args.save_scores``)
         and run ``overall_eval`` on the top and/or oracle candidates.

    Args:
        args: parsed command-line namespace (see the argparse definitions in
            this file).  It is mutated: ``args.device`` is set here.

    Raises:
        ValueError: if ``args.label_metric`` is not a supported metric, if no
            example is left to score, if there are fewer labels than
            summaries, or (BERTScore/BARTScore only) if an example does not
            contain exactly ``args.num_candidates`` candidates.
    """
    metric_name = args.label_metric
    is_rouge = "rouge" in metric_name
    # Fail fast on a typo instead of hitting a NameError on `scores` after
    # the data has been loaded.
    if not is_rouge and metric_name not in ("bertscore", "bleurt", "bartscore"):
        raise ValueError("Unsupported --label_metric: {!r}".format(metric_name))

    # seed
    seed_everything(args.seed)

    # device
    device = torch.device("cpu")
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda")
    args.device = device
    print("Using device: {}".format(device))

    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point --summaries_path at data you trust.
    # load summaries
    summaries_path = args.summaries_path + "{}/{}_summaries_{}_{}_beams_{}.pkl".format(args.set, args.set, args.model_name, args.dataset_size, args.num_candidates)
    with open(summaries_path, "rb") as f:
        summaries = pickle.load(f)
    print("Loaded {} summaries".format(len(summaries)))

    # load labels
    labels_path = args.summaries_path + "{}/{}_labels_{}_beams_{}.pkl".format(args.set, args.set, args.dataset_size, args.num_candidates)
    with open(labels_path, "rb") as f:
        labels = pickle.load(f)
    print("Loaded {} labels".format(len(labels)))

    summaries = summaries[:args.size_to_score]
    labels = labels[:args.size_to_score]
    if len(summaries) == 0:
        raise ValueError("No summaries to score (check --size_to_score and {})".format(summaries_path))
    if len(labels) < len(summaries):
        raise ValueError("Got {} labels for {} summaries".format(len(labels), len(summaries)))

    # Preview at most 3 candidates of the first example; slicing keeps this
    # safe when fewer than 3 candidates were generated.
    for candidate in summaries[0][:3]:
        print(candidate)

    # score summaries against the labels
    print("\nSCORING SUMMARIES WITH: {}".format(metric_name))

    # init
    if is_rouge:
        # Stemming is always enabled for the labelling scorer; args.stemmer
        # is not consulted in this function.
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer = True)
        scores = []
    elif metric_name in ("bertscore", "bartscore"):
        # Both metrics are run once per candidate rank: all_summaries[j]
        # collects the j-th candidate of every example, aligned with all_labels.
        all_labels = []
        all_summaries = [[] for _ in range(args.num_candidates)]
        if metric_name == "bartscore":
            bart_scorer = BARTScorer(device=device, checkpoint='facebook/bart-large-cnn')
    else:  # bleurt
        metric = load_metric('bleurt', keep_in_memory = True)
        # Number of candidates submitted per example, used to split the flat
        # BLEURT output back into per-example lists.
        candidate_counts = []

    t1 = time()
    # loop
    for summaries_i, label in tqdm(zip(summaries, labels), total=len(summaries)):
        if is_rouge:
            scores.append(get_rouge_scores(label, summaries_i, scorer, args))
        elif metric_name in ("bertscore", "bartscore"):
            if len(summaries_i) != args.num_candidates:
                raise ValueError(
                    "Expected {} candidates per example, got {}".format(args.num_candidates, len(summaries_i))
                )
            all_labels.append(label)
            for rank, candidate in enumerate(summaries_i):
                all_summaries[rank].append(candidate)
        else:  # bleurt
            metric.add_batch(predictions=summaries_i, references=[label] * len(summaries_i))
            candidate_counts.append(len(summaries_i))

    # conclusion
    if metric_name == "bertscore":
        all_f1 = []
        for j in range(len(all_summaries)):
            print(j, len(all_summaries[j]))
            _, _, f1 = bertscore_score(all_summaries[j], all_labels, lang='en', verbose=True, batch_size=16)
            all_f1.append(f1)
            gc.collect()
        # Transpose from [rank][example] to [example][rank].
        scores = [[all_f1[j][i].item() for j in range(len(all_f1))] for i in range(len(all_f1[0]))]
    elif metric_name == "bleurt":
        bleurt_scores = metric.compute()["scores"]
        print(len(bleurt_scores))
        scores = []
        offset = 0
        # Split by the number of candidates actually submitted per example
        # rather than the nominal args.num_candidates, so the scores stay
        # aligned even if some example has a different candidate count.
        for count in candidate_counts:
            scores.append(bleurt_scores[offset:offset + count])
            offset += count
    elif metric_name == "bartscore":
        all_bartscores = []
        for j in range(len(all_summaries)):
            print(j, len(all_summaries[j]))
            bartscores = bart_scorer.score(all_labels, all_summaries[j], batch_size=16)
            all_bartscores.append(bartscores)
        # Transpose from [rank][example] to [example][rank].
        scores = [[all_bartscores[j][i] for j in range(len(all_bartscores))] for i in range(len(all_bartscores[0]))]
    t2 = time()
    print("Time to get the scores: {:.4f}".format(t2-t1))
    print(len(scores), len(scores[0]))
    print(type(scores[0]))
    print(scores[0])
    scored_summaries = [[summaries[i], scores[i]] for i in range(len(summaries))]
    # Candidate 0 is reported as the "top beam"; the oracle is the
    # best-scoring candidate of each example.
    top_scores = [scores[i][0] for i in range(len(summaries))]
    oracle_scores = [np.max(scores[i]) for i in range(len(summaries))]

    print(len(scored_summaries))
    print("Mean score (top beam): {:.4f}".format(np.mean(np.array(top_scores))))
    print("ORACLE score: {:.4f}".format(np.mean(np.array(oracle_scores))))

    if args.save_scores:
        save_path = args.scored_summaries_path + "{}/{}_scored_summaries_{}_{}_beams_{}.pkl".format(args.set, args.set, args.model_name, args.dataset_size, args.num_candidates)
        with open(save_path, "wb") as f:
            pickle.dump(scored_summaries, f)
            print("saved new data!", save_path)

    if args.eval_top_candidate:
        val_summaries = [summaries[i][0] for i in range(len(summaries))]
        print("\n\n")
        print("*"*50)
        print("Top candidate evaluation:")
        overall_eval(None, val_summaries, labels, args)

    if args.eval_oracle:
        val_summaries = [summaries[i][np.argmax(scores[i])] for i in range(len(summaries))]
        print("\n\n")
        print("*"*50)
        print("Oracle evaluation:")
        overall_eval(None, val_summaries, labels, args)


def get_rouge_scores(label, summaries_i, scorer, args):
    """Return one ROUGE-based score (0-100 scale) per candidate summary.

    Args:
        label: reference summary, passed as the first argument of
            ``scorer.score``.
        summaries_i: iterable of candidate summary strings for one example.
        scorer: object whose ``score(label, summary)`` returns a mapping with
            "rouge1", "rouge2" and "rougeLsum" entries exposing ``.fmeasure``.
        args: namespace providing ``label_metric`` ("mean_rouge", "rouge_1",
            "rouge_2" or "rouge_l"), ``clean_n`` (replace "<n>" markers with a
            space) and ``highlights`` (put one sentence per line before
            scoring).

    Returns:
        A list of floats, one per candidate, in the order of ``summaries_i``.

    Raises:
        ValueError: if ``args.label_metric`` is not one of the supported ROUGE
            variants (this used to surface as an UnboundLocalError).
    """
    metric = args.label_metric
    if metric not in ("mean_rouge", "rouge_1", "rouge_2", "rouge_l"):
        raise ValueError("Unsupported ROUGE label metric: {!r}".format(metric))

    scores_i = []
    for summary in summaries_i:
        if args.clean_n:
            summary = summary.replace("<n>", " ")
        if args.highlights:
            # One sentence per line before scoring with rougeLsum.
            summary = "\n".join(sent_tokenize(summary))
        rouge_scores = scorer.score(label, summary)
        r1 = 100 * rouge_scores["rouge1"].fmeasure
        r2 = 100 * rouge_scores["rouge2"].fmeasure
        rl = 100 * rouge_scores["rougeLsum"].fmeasure
        if metric == "mean_rouge":
            score = (r1 + r2 + rl) / 3
        elif metric == "rouge_1":
            score = r1
        elif metric == "rouge_2":
            score = r2
        else:  # "rouge_l"
            score = rl
        scores_i.append(score)

    return scores_i



# Script entry point: run the scoring pipeline with the arguments parsed at
# module level.
if __name__ == "__main__":
    main(args)
