# File to evaluate the quality of a set of specific generated summaries.

import sys
import pickle
import argparse 

sys.path.append("xxx")

from common.utils import *
from common.evaluation import *



def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag could never be
    switched off from the command line. This converter accepts the usual
    spellings (true/false, yes/no, t/f, y/n, 1/0), case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got {!r}".format(value))


def _parse_set(spec):
    """Parse a ``NAME:SIZE`` command-line token into a ``(name, size)`` tuple.

    Raises:
        argparse.ArgumentTypeError: if the token is not of the form NAME:SIZE
            with an integer SIZE.
    """
    name, sep, size = spec.rpartition(":")
    if not sep or not name:
        raise argparse.ArgumentTypeError(
            "expected NAME:SIZE, got {!r}".format(spec))
    try:
        return (name, int(size))
    except ValueError:
        raise argparse.ArgumentTypeError(
            "SIZE must be an integer in {!r}".format(spec))


parser = argparse.ArgumentParser()

parser.add_argument('--cuda', type=_str2bool, default=True)

# data
parser.add_argument('--base_dir', type=str, default="../summaries/CNNDM/")
# One or more generation method names, space-separated on the command line.
# (``type=list`` would have split a single string into its characters.)
parser.add_argument('--generation_methods', type=str, nargs='+', default=[
    "1_beam_search",
    #"2_diverse_beam_search"
])
# One or more NAME:SIZE pairs, e.g. ``--sets small_val:300 test:11490``.
# Each entry ends up as a (set_name, size) tuple, matching the defaults below.
parser.add_argument('--sets', type=_parse_set, nargs='+', metavar='NAME:SIZE', default=[
    #("first_half_train_shuffled", 143000),
    ("small_val", 300),
    #("test", 11490),
    #("test", 11334)
])

parser.add_argument('--n_show_summaries', type=int, default=0)
parser.add_argument('--rouge_to_use', type=str, default="rouge_score") # in ["rouge_score", "rouge"]
parser.add_argument('--highlights', type=_str2bool, default=True) # CNNDM: True / WikiHow: False / XSum: False
parser.add_argument('--clean_n', type=_str2bool, default=True) # CNNDM: True / WikiHow: False / XSum: False

parser.add_argument('--stemmer', type=_str2bool, default=True)

# metrics
# 1 - ROUGE
parser.add_argument('--eval_rouge', type=_str2bool, default=True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type=_str2bool, default=True)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type=_str2bool, default=True)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type=_str2bool, default=False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type=_str2bool, default=False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type=_str2bool, default=False)
# 0_stats
parser.add_argument('--check_correlation', type=_str2bool, default=False)

# Parses sys.argv at import time; unknown arguments abort the script.
args = parser.parse_args()

# Echo the parsed configuration under a separator line.
print("*" * 50)
print(args)



# Use the GPU only when it was requested AND torch can see one; otherwise
# fall back to the CPU. The chosen device is stored on ``args`` so that it
# travels with the rest of the configuration.
device = torch.device(
    "cuda" if (args.cuda and torch.cuda.is_available()) else "cpu"
)
args.device = device
print("\nUsing device {}".format(device))

def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed even if
    unpickling fails.
    """
    # NOTE: pickle can execute arbitrary code on load; only point this
    # script at files you generated yourself.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# For every (generation method, data set) pair: load the source texts and
# reference labels, then evaluate each matching summaries file found in the
# same directory.
for generation_method in args.generation_methods:
    for set_name, size in args.sets:
        # The trailing "" keeps the final separator, and os.path.join copes
        # with --base_dir given with or without a trailing slash.
        set_path = os.path.join(args.base_dir, generation_method, set_name, "")
        print("*"*50, set_path)

        texts_path = set_path + "{}_texts_{}_beams_15.pkl".format(set_name, size)
        texts = _load_pickle(texts_path)
        labels_path = set_path + "{}_labels_{}_beams_15.pkl".format(set_name, size)
        labels = _load_pickle(labels_path)
        print("# texts: {}, # labels: {}".format(len(texts), len(labels)))

        # Candidate files: summaries for this set size, skipping any
        # "unsupervised" ones. os.listdir returns entries in arbitrary order,
        # so sort to make the printed evaluation index reproducible.
        suffix = "{}_beams_15.pkl".format(size)
        summary_files = [
            file_name
            for file_name in sorted(os.listdir(set_path))
            if file_name.endswith(suffix)
            and "summaries" in file_name
            and "unsupervised" not in file_name
        ]

        for count, file_name in enumerate(summary_files):
            summaries = _load_pickle(set_path + file_name)
            print(len(summaries))
            print("*"*30, "{} - Evaluating {} summaries".format(count, file_name))
            # Keep only the first element of every entry
            # (presumably the top-ranked candidate per document - TODO confirm).
            base_summaries = [x[0] for x in summaries]
            print(len(base_summaries))
            # overall_eval is not defined in this file; presumably it comes
            # from the star imports at the top - TODO confirm.
            overall_eval(texts, base_summaries, labels, args)

