# Generate summary candidates with the fine-tuned models.

import time
import argparse
import sys

sys.path.append("xxx")

from common.utils import *
from common.evaluation import *
from data import *
from dataset import *
from model import *
from engine import *
from model_utils import *



def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    (any non-empty string is truthy), so such flags could never be turned off
    from the command line. This converter accepts the usual spellings instead.

    Args:
        value: Raw command-line string, or an already-parsed bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()

parser.add_argument('--seed', type = int, default = 42)
parser.add_argument('--cuda', type = _str2bool, default = True)
parser.add_argument('--accelerate', type = _str2bool, default = True)
parser.add_argument('--debug', type = _str2bool, default = False)
parser.add_argument('--debug_size', type = int, default = 5)

# task
parser.add_argument('--prompt', type = str, default = "summarize: ") # in ["summarize: ", "translate: ", "summarize_translate: "]
parser.add_argument('--add_prompt_to_text', type = _str2bool, default = False) # SHOULD BE False

# data
parser.add_argument('--data_folder', type = str, default = "../../DATASETS/RedditTIFU/data/") # CNNDM / WikiHow / XSum / RedditTIFU / BillSum
# nargs="+" collects one or more language codes; type=list would split a single code into characters.
parser.add_argument('--languages', type = str, nargs = "+", default = ["en"])
parser.add_argument('--max_length', type = int, default = 512) # CNNDM: 1024 / WikiHow: 512 / XSum: 512 / Reddit: 512 / BillSum: 1024
parser.add_argument('--check_data_pipe', type = _str2bool, default = False)
parser.add_argument('--truncate_text', type = _str2bool, default = False)
parser.add_argument('--max_text_size', type = int, default = 3000)

# model
parser.add_argument('--model_type', type = str, default = "pegasus") # in ["t5", "pegasus", "bart"]
parser.add_argument('--model', type = str, default = "google/pegasus-large")
# in ["t5-base", "google/t5-v1_1-base", "google/pegasus-large", "google/pegasus-cnn_dailymail", "google/pegasus-xsum", "facebook/bart-large", "facebook/bart-large-cnn"]
parser.add_argument('--model_name', type = str, default = "pegasus_reddit_train_1_v2")
# in ["pegasus_unsupervised", "pegasus_cnndm_first_half_shuffled_1", "pegasus_cnndm_second_half_shuffled_1", "pegasus_cnndm", "bart_unsupervised", "bart_cnndm_first_half_shuffled_1", "bart_cnndm_second_half_shuffled_1", "bart_cnndm"]
parser.add_argument('--hidden_size', type = int, default = 768) # 512 / 768 / 1024
parser.add_argument('--cache_dir', type = str, default = "../../hf_models/pegasus-large/")
# in ["t5-base", "t5-base-v1", "pegasus-large", "pegasus-large-cnndm", "pegasus-large-xsum", "bart-large", "bart-large-cnndm"]
parser.add_argument('--load_model', type = _str2bool, default = True)
parser.add_argument('--load_model_path', type = str, default = "../1_base_finetuning/ft_saved_models/pegasus_reddit_train_1/checkpoint-1250/pytorch_model.bin")
parser.add_argument('--ft_model', type = _str2bool, default = True)

# summary generation
### overall
parser.add_argument('--val_dataset', type = str, default = "small_val")
# in ["train", "first_half_train", "second_half_train", "first_half_train_shuffled", "second_half_train_shuffled", "val", "filtered_val", "small_val", "test"]
parser.add_argument('--val_data_size', type = int, default = 300)
# in [287113, 143000 / 102000, 144113 / 102045, 13368 / 11332, 13068, 300, 11490 / 11334]
# NOTE(review): the original note says to use 1 for diverse beam search, yet the default
# generation_method is diverse_beam_search with a batch size of 2 -- confirm which is intended.
parser.add_argument('--inference_bs', type = int, default = 2) # 1 FOR DIVERSE BEAM SEARCH
parser.add_argument('--save_summaries', type = _str2bool, default = True)
parser.add_argument('--save_summaries_path', type = str, default = "../summaries/Reddit/2_diverse_beam_search/")
### generation method
parser.add_argument('--generation_method', type = str, default = "diverse_beam_search")
# in ["beam_search", "diverse_beam_search", "top_p_sampling", "top_k_sampling"]
parser.add_argument('--num_return_sequences', type = int, default = 15) # default: 15
# 1: beam search
parser.add_argument('--num_beams', type = int, default = 15) # default: 15
# 2: diverse beam search
parser.add_argument('--num_beam_groups', type = int, default = 15) # default: 15
parser.add_argument('--diversity_penalty', type = float, default = 1.0) # default: 1.0
# 3: top-p sampling
parser.add_argument('--top_p', type = float, default = 0.95) # default: 1.0
# 4: top-k sampling
parser.add_argument('--top_k', type = int, default = 50) # default: 50
### other generation hyper-parameters
parser.add_argument('--max_summary_length', type = int, default = 128) # CNNDM: 128 / WikiHow: 128 / XSum: 64 / Reddit: 128 / BillSum: 256
parser.add_argument('--length_penalty', type = float, default = 1.5) # CNNDM: Pegasus: 0.8, BART: 1.0 / WikiHow: Pegasus: 0.6, BART: 1.0 / Xsum: Pegasus: 0.8, BART: 0.8
parser.add_argument('--repetition_penalty', type = float, default = 2.0) # 1.0
parser.add_argument('--no_repeat_ngram_size', type = int, default = 3) # Pegasus: 0 / BART: 3
parser.add_argument('--stemmer', type = _str2bool, default = True)

parser.add_argument('--n_show_summaries', type = int, default = 3)
parser.add_argument('--rouge_to_use', type = str, default = "rouge_score") # in ["rouge_score", "rouge"]
parser.add_argument('--highlights', type = _str2bool, default = False) # CNNDM: True / WikiHow: False / XSum: False / Reddit: False
parser.add_argument('--clean_n', type = _str2bool, default = False) # CNNDM: True / WikiHow: False / XSum: False / Reddit: False

# metrics for summaries quality
# 1 - ROUGE
parser.add_argument('--eval_rouge', type = _str2bool, default = True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type = _str2bool, default = True)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type = _str2bool, default = True)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type = _str2bool, default = False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type = _str2bool, default = False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type = _str2bool, default = False)
# 0_stats
parser.add_argument('--check_correlation', type = _str2bool, default = False)

args = parser.parse_args()

print("*"*50)
print(args)



def main(args):
    """Generate summary candidates on one split and optionally pickle them.

    Steps:

    - Seed the RNGs and select the device.
    - Load ``args.val_dataset`` and truncate it to ``args.val_data_size`` examples,
      or to ``args.debug_size`` examples in debug mode.
    - Build the model, optionally wrap it in ``FTModel``, and optionally load
      weights from ``args.load_model_path``.
    - Generate candidates with ``get_summaries``.
    - Evaluate the first candidate of every example with ``overall_eval``.
    - When ``args.save_summaries`` is set, pickle the texts, candidates and labels
      under ``<save_summaries_path>/<val_dataset>/``.

    Args:
        args: Parsed command-line namespace. Mutated: ``args.device`` is set to
            the selected ``torch.device``.
    """
    import os
    import pickle

    # seed
    seed_everything(args.seed)

    # device: fall back to CPU when CUDA is disabled or unavailable
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    args.device = device
    print("\nUsing device {}".format(device))

    # data
    texts, summaries = load_data(args.val_dataset, args, individual_txt = args.highlights)

    # tokenizer
    tokenizer = build_tokenizer(args)

    # dataset: keep the first val_data_size examples (fewer in debug mode)
    print(len(texts), len(summaries))
    texts = texts[:args.val_data_size]
    summaries = summaries[:args.val_data_size]
    print(len(texts), len(summaries))
    if args.debug:
        texts = texts[:args.debug_size]
        summaries = summaries[:args.debug_size]
    val_dataset = Dataset(args.val_dataset, tokenizer, texts, summaries, args)
    print("Total size of dataset: {}".format(len(texts)))

    # data loader: no shuffling, so batches follow the dataset order
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = args.inference_bs, shuffle = False)

    # check data pipe
    if args.check_data_pipe:
        check_data_pipe([val_loader])

    # model
    model = build_model(args)
    if args.ft_model:
        model = FTModel(model, args)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nThe model has {} trainable parameters".format(n_params))
    model = model.to(device)
    if args.load_model:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        state_dict = torch.load(args.load_model_path, map_location = device)
        model.load_state_dict(state_dict)
        print("Loaded the model weights!", args.load_model_path)

    # summary generation
    val_texts, val_summaries, val_labels = get_summaries(tokenizer, val_loader, model, device, args)

    # evaluation: score the first candidate of each example
    # (presumably the top-ranked beam -- depends on get_summaries' ordering)
    base_results = [candidates[0] for candidates in val_summaries]
    print("*"*100)
    print("\nTop beam:")
    overall_eval(val_texts, base_results, val_labels, args)

    # export
    if args.save_summaries:
        # Guard against an empty generation result instead of raising IndexError.
        num_candidates = len(val_summaries[0]) if len(val_summaries) > 0 else 0
        n_texts = len(val_texts)
        save_dir = os.path.join(args.save_summaries_path, args.val_dataset)
        # Create the output folder rather than failing in open() when it is missing.
        os.makedirs(save_dir, exist_ok = True)
        outputs = [
            ("{}_texts_{}_beams_{}.pkl".format(args.val_dataset, n_texts, num_candidates), val_texts),
            ("{}_summaries_{}_{}_beams_{}.pkl".format(args.val_dataset, args.model_name, n_texts, num_candidates), val_summaries),
            ("{}_labels_{}_beams_{}.pkl".format(args.val_dataset, n_texts, num_candidates), val_labels),
        ]
        for file_name, obj in outputs:
            with open(os.path.join(save_dir, file_name), "wb") as f:
                pickle.dump(obj, f)
        print("saved generated summaries!", save_dir)



# Script entry point: run generation with the arguments parsed at import time.
if __name__ == "__main__":
    main(args)
