# Script that evaluates the quality of a set of generated candidate summaries.

import sys
import pickle
import argparse 

sys.path.append("xxx")

from common.utils import *
from common.evaluation import *



def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag could never be turned
    off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, as it was with ``bool("")``.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got {!r}".format(value)
    )


def _set_spec(value):
    """Parse a ``NAME:SIZE`` command-line token into a ``(name, size)`` tuple.

    Raises:
        argparse.ArgumentTypeError: if the token is not of the form NAME:SIZE
            with an integer SIZE.
    """
    name, sep, size = value.rpartition(":")
    if not sep or not name:
        raise argparse.ArgumentTypeError(
            "expected NAME:SIZE (e.g. test:11490), got {!r}".format(value)
        )
    try:
        return (name, int(size))
    except ValueError:
        raise argparse.ArgumentTypeError(
            "size must be an integer in {!r}".format(value)
        )


# Command-line interface of the evaluation script. Defaults are unchanged;
# only the string -> value conversions were fixed.
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type = int, default = 42)
parser.add_argument('--cuda', type = _str2bool, default = True)

# data
parser.add_argument('--base_dir', type = str, default = "../summaries/CNNDM/")
# One or more generation-method directory names (space separated).
# ``type=list`` would have split a single string into its characters.
parser.add_argument('--generation_methods', type = str, nargs = '+', default = [
    #"1_beam_search", 
    #"2_diverse_beam_search",
    "3_top_p_sampling",
    #"4_top_k_sampling"
], help = "generation method sub-directories to evaluate")
# One or more NAME:SIZE tokens (space separated), each parsed to (name, size).
parser.add_argument('--sets', type = _set_spec, nargs = '+', default = [
    #("small_val", 300),
    ("test", 11490),
    # ("test", 11334)
], help = "data splits to evaluate, given as NAME:SIZE (e.g. test:11490)")
parser.add_argument('--model', type = str, default = "pegasus_cnndm")
parser.add_argument('--n_beams', type = int, default = 15)

parser.add_argument('--n_show_summaries', type = int, default = 0)
parser.add_argument('--rouge_to_use', type = str, default = "rouge_score") # in ["rouge_score", "rouge"]
parser.add_argument('--highlights', type = _str2bool, default = True) # CNNDM: True / WikiHow: False / XSum: False
parser.add_argument('--clean_n', type = _str2bool, default = True) # CNNDM: True / WikiHow: False / XSum: False

parser.add_argument('--stemmer', type = _str2bool, default = True)

# metrics
# 1 - ROUGE
parser.add_argument('--eval_rouge', type = _str2bool, default = True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type = _str2bool, default = True)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type = _str2bool, default = True)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type = _str2bool, default = False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type = _str2bool, default = False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type = _str2bool, default = False)
# 0_stats
parser.add_argument('--check_correlation', type = _str2bool, default = False)

# Parse the command line and echo the resulting configuration under a banner.
args = parser.parse_args()

print("*" * 50)
print(args)

# Seed with the user-provided value. `seed_everything` is not defined in this
# file (presumably it comes from the `common` star imports - confirm there).
seed_everything(args.seed)

# Pick the GPU only when it was requested *and* is actually available;
# otherwise fall back to the CPU. The chosen device is stored on `args` so it
# travels with the rest of the configuration.
device = torch.device(
    "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
)
args.device = device
print("\nUsing device {}".format(device))

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file is opened in a ``with`` block so the handle is closed right after
    loading instead of being leaked.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only point
    # this script at summary files you generated or otherwise trust.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# For every (generation method, data split) pair: load the pickled texts,
# labels and candidate summaries, then evaluate the first candidate of each
# entry (the "TOP" baseline).
for generation_method in args.generation_methods:
    # Unpack directly in the loop header; the previous loop variable was named
    # `set`, which shadowed the builtin for the rest of the module.
    for set_name, size in args.sets:
        set_path = args.base_dir + "{}/{}/".format(generation_method, set_name)

        texts_path = set_path + "{}_texts_{}_beams_{}.pkl".format(set_name, size, args.n_beams)
        texts = _load_pickle(texts_path)

        labels_path = set_path + "{}_labels_{}_beams_{}.pkl".format(set_name, size, args.n_beams)
        labels = _load_pickle(labels_path)

        summaries_path = set_path + "{}_summaries_{}_{}_beams_{}.pkl".format(set_name, args.model, size, args.n_beams)
        summaries = _load_pickle(summaries_path)

        print("Summaries from: {}".format(summaries_path))
        print("# texts: {}, # labels: {}, # summaries: {}".format(len(texts), len(labels), len(summaries)))

        # TOP baseline: first candidate of every entry.
        base_summaries = [x[0] for x in summaries]
        print("\n\nEvaluating the TOP baseline\n")
        print(len(base_summaries))
        overall_eval(texts, base_summaries, labels, args)

        # RANDOM baseline: one candidate per entry, drawn uniformly from
        # indices [0, n_beams). Assumes every entry holds at least
        # `args.n_beams` candidates - TODO confirm.
        random_summaries = [x[np.random.randint(args.n_beams)] for x in summaries]
        print("\n\nEvaluating the RANDOM baseline\n")
        print(len(random_summaries))
        # Evaluation of the random baseline is intentionally disabled:
        #overall_eval(texts, random_summaries, labels, args)

