# Manually inspect the data generated for re-ranking.

import argparse
import sys
import pickle

sys.path.append("xxx")

from tqdm import tqdm

from common.utils import seed_everything



# Command-line configuration. Arguments are parsed at import time so that the
# resulting `args` namespace is available to the `__main__` entry point.
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=42,
                    help="random seed passed to seed_everything")

# data
# List-valued options take one or more space-separated names, e.g.
#   --scoring_methods 1a_rouge_1 2_bertscore
# They use nargs='+' with type=str; `type=list` would turn a single command-line
# string into a list of its characters.
parser.add_argument('--generation_methods', type=str, nargs='+', default=[
    # other generation methods: "1_beam_search", "3_top_p_sampling",
    # "4_top_k_sampling"
    "2_diverse_beam_search",
], help="generation methods to inspect (sub-directories of --scored_summaries_path)")
parser.add_argument('--scoring_methods', type=str, nargs='+', default=[
    # other scoring methods: "1_mean_rouge"
    "1a_rouge_1",
    "1b_rouge_2",
    "1c_rouge_l",
    "2_bertscore",
    "4_bartscore",
], help="scoring methods to inspect (sub-directories of each generation method)")
parser.add_argument('--scored_summaries_path', type=str, default="../reranking_data/CNNDM/",
                    help="root directory holding the scored summary pickles")

# base model
parser.add_argument('--model_name', type=str, default="pegasus_cnndm",
                    help="base model name embedded in the pickle file names")
# in ["pegasus_unsupervised", "pegasus_cnndm_first_half_shuffled_1", "pegasus_cnndm_second_half_shuffled_1", "pegasus_cnndm",
# "bart_unsupervised", "bart_cnndm_first_half_shuffled_1", "bart_cnndm_second_half_shuffled_1", "bart_cnndm"]
parser.add_argument('--num_beams', type=int, default=15,
                    help="beam count embedded in the pickle file names")

args = parser.parse_args()
# One task per scoring method.
args.n_tasks = len(args.scoring_methods)

print("*" * 50)
print(args)



def main(args, num_examples=3):
    """Print a preview of the scored summary pickles for every method pair.

    For each generation method in ``args.generation_methods`` and each scoring
    method in ``args.scoring_methods``, load the ``small_val`` pickle (size tag
    300, ``args.num_beams`` beams, ``args.model_name``) stored under
    ``args.scored_summaries_path/<generation>/<scoring>/small_val/``. For each
    one, print the path, the number of top-level entries and the first
    ``num_examples`` paired items of entry ``[0]``.

    Args:
        args: parsed CLI namespace. Reads ``seed``, ``generation_methods``,
            ``scoring_methods``, ``scored_summaries_path``, ``model_name`` and
            ``num_beams``.
        num_examples: maximum number of paired items printed per file
            (default 3). Fewer are printed if the file holds fewer.

    Raises:
        FileNotFoundError: if an expected pickle file does not exist.
    """
    import itertools
    import os

    # seed
    seed_everything(args.seed)

    set_ = "small_val"
    size_ = 300
    file_name = "{}_scored_summaries_{}_{}_beams_{}.pkl".format(
        set_, args.model_name, size_, args.num_beams
    )
    for generation_method in args.generation_methods:
        for scoring_method in args.scoring_methods:
            # os.path.join works whether or not the base path ends with a slash.
            path = os.path.join(
                args.scored_summaries_path, generation_method, scoring_method, set_, file_name
            )
            print(path)
            # NOTE: pickle.load can execute arbitrary code; only point this
            # script at pickles produced by this project.
            with open(path, "rb") as f:
                scored_summaries = pickle.load(f)
            print(len(scored_summaries))
            if not scored_summaries:
                continue
            # Assumes entry [0] holds two parallel sequences, presumably
            # (candidate summaries, scores) — TODO confirm against the writer.
            first, second = scored_summaries[0][0], scored_summaries[0][1]
            for a, b in itertools.islice(zip(first, second), num_examples):
                print(a, b)


if __name__ == "__main__":
    # Script entry point: inspect the data using the module-level CLI arguments.
    main(args=args)
