# Evaluate fine-tuned models

import argparse
import torch
import time

from utils import *
from data import *
from dataset import *
from dataset_trainer import *
from transfer_utils import *
from model import FTModel
from engine import validate



def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so such a flag can never be
    switched off from the command line.

    Args:
        value: raw string given on the command line (a bool is passed through).

    Returns:
        The parsed boolean. The empty string maps to ``False``, which is what
        ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got {!r}".format(value)
    )


parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--cuda', type=_str2bool, default=True)
parser.add_argument('--accelerate', type=_str2bool, default=True)
parser.add_argument('--debug', type=_str2bool, default=False)
parser.add_argument('--debug_size', type=int, default=30)

# task
# prompt in ["summarize: ", "translate: ", "summarize_translate: "]
parser.add_argument('--prompt', type=str, default="summarize: ")
parser.add_argument('--add_prompt_to_text', type=_str2bool, default=False)

# data
# data_folder: CNNDM / WikiHow / XSum / RedditTIFU / BillSum
parser.add_argument('--data_folder', type=str, default="../../DATASETS/RedditTIFU/data/")
parser.add_argument('--val_size_per_language', type=int, default=100000)
parser.add_argument('--train_max_size', type=int, default=20000)
parser.add_argument('--val_max_size', type=int, default=300000)
# max_length: CNNDM: 1024 / WikiHow: 512 / XSum: 512 / RedditTIFU: 512 / BillSum: 768 (700)
parser.add_argument('--max_length', type=int, default=512)
parser.add_argument('--check_data_pipe', type=_str2bool, default=False)
parser.add_argument('--truncate_text', type=_str2bool, default=False)
parser.add_argument('--max_text_size', type=int, default=3000)
parser.add_argument('--compute_r1s', type=_str2bool, default=False)

# model
# model_type in ["t5", "pegasus", "bart"]
parser.add_argument('--model_type', type=str, default="pegasus")
# model in ["t5-base", "google/t5-v1_1-base", "google/pegasus-large",
# "google/pegasus-cnn_dailymail", "google/pegasus-xsum", "google/pegasus-reddit_tifu",
# "facebook/bart-large"]
parser.add_argument('--model', type=str, default="google/pegasus-large")
# hidden_size: 512 / 768 / 1024
parser.add_argument('--hidden_size', type=int, default=768)
# cache_dir in ["t5-base", "t5-base-v1", "pegasus-large", "pegasus-large-cnndm",
# "pegasus-large-xsum", "pegasus-large-reddit", "bart-large"]
parser.add_argument('--cache_dir', type=str, default="../../hf_models/pegasus-large/")
parser.add_argument('--load_model', type=_str2bool, default=True)
parser.add_argument('--save_model_path', type=str, default="ft_saved_models/pegasus_reddit_train_1/checkpoint-1250/pytorch_model.bin")

# evaluation
parser.add_argument('--train_bs', type=int, default=12)
parser.add_argument('--inference_bs', type=int, default=4)
parser.add_argument('--val_dataset', type=str, default="small_val")

# metrics
# 1 - ROUGE
parser.add_argument('--eval_rouge', type=_str2bool, default=True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type=_str2bool, default=False)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type=_str2bool, default=False)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type=_str2bool, default=False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type=_str2bool, default=False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type=_str2bool, default=False)
# 7_stats
parser.add_argument('--check_correlation', type=_str2bool, default=False)

# summary generation
parser.add_argument('--inference', type=_str2bool, default=False)
parser.add_argument('--generation', type=_str2bool, default=True)
# num_beams: Pegasus: 8, BART: 5
parser.add_argument('--num_beams', type=int, default=8)
# num_return_sequences default: 1
parser.add_argument('--num_return_sequences', type=int, default=1)
# max_summary_length: CNNDM: 128 / WikiHow: 128 / XSum: 64 / Reddit: 128 / BillSum: 256
parser.add_argument('--max_summary_length', type=int, default=128)
# length_penalty: CNNDM: Pegasus: 0.8, BART: 1.0 / WikiHow: Pegasus: 0.6, BART: 1.0 /
# XSum: Pegasus: 0.8, BART: 0.8 / Reddit: Pegasus: 0.6 / BillSum: Pegasus: 0.8
parser.add_argument('--length_penalty', type=float, default=0.6)
# repetition_penalty: 1.0
parser.add_argument('--repetition_penalty', type=float, default=1.0)
# no_repeat_ngram_size: Pegasus: 0 / BART: 3
parser.add_argument('--no_repeat_ngram_size', type=int, default=3)
parser.add_argument('--stemmer', type=_str2bool, default=True)
parser.add_argument('--n_show_summaries', type=int, default=1)
# rouge_to_use in ["rouge_score", "rouge"]
parser.add_argument('--rouge_to_use', type=str, default="rouge_score")
# highlights: CNNDM: True / WikiHow: False / XSum: False
parser.add_argument('--highlights', type=_str2bool, default=False)
# clean_n: CNNDM: True / WikiHow: False / XSum: False
parser.add_argument('--clean_n', type=_str2bool, default=False)

# Parse the command line at module level; the namespace is later passed to main().
args = parser.parse_args()

# Print a separator line followed by the parsed arguments.
print("*" * 50, args, sep="\n")

#time.sleep(7000)



def main(args):
    """Evaluate a model on the validation set described by ``args``.

    Steps, in order: call ``seed_everything``, choose the device, load the
    validation data, build the tokenizer, wrap the data in an
    ``InferenceFTDataset`` and a ``DataLoader``, build the model (optionally
    restoring weights from ``args.save_model_path``) and call ``validate``.

    Args:
        args: parsed command-line namespace. ``args.device`` is set here as a
            side effect so that downstream helpers receiving ``args`` can
            read it.
    """
    # seed
    seed_everything(args.seed)

    # device: use CUDA only when it is both requested and available
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    args.device = device
    print("\nUsing device {}".format(device))

    # data
    val_data = load_data(args.val_dataset, args, individual_txt=args.highlights)

    # tokenizer
    tokenizer = build_tokenizer(args)

    # dataset
    mode = "val"
    texts, summaries = val_data
    print(len(texts), len(summaries))
    if args.debug:
        texts = texts[:args.debug_size]
        summaries = summaries[:args.debug_size]
    texts = texts[:args.val_max_size]
    summaries = summaries[:args.val_max_size]
    val_dataset = InferenceFTDataset(mode, tokenizer, texts, summaries, args)
    # The loader below keeps the last partial batch (drop_last defaults to
    # False), so the batch count is a ceiling division, not a floor.
    n_examples = len(val_dataset.texts)
    n_batches = (n_examples + args.inference_bs - 1) // args.inference_bs
    print("There are {} {} batches".format(n_batches, mode))

    # data loader
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.inference_bs, shuffle=False
    )

    # check data pipe
    if args.check_data_pipe:
        check_data_pipe([val_loader])

    # model
    base_model = build_model(args)
    model = FTModel(base_model, args)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nThe model has {} trainable parameters".format(n_params))
    model = model.to(device)
    if args.load_model:
        print("loading the weights: {}".format(args.save_model_path))
        # map_location remaps the stored tensors onto the selected device, so a
        # checkpoint saved on GPU can still be restored on a CPU-only machine.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(args.save_model_path, map_location=device)
        model.load_state_dict(state_dict)
        print("loaded the model weights!")

    # evaluation
    validate(mode, val_loader, [], tokenizer, model, device, args)


# Script entry point: run the evaluation with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)
