# Evaluate the performance of a trained re-ranker.

import argparse
import sys
import time

sys.path.append("xxx")

from tqdm import tqdm

from common.utils import seed_everything
from common.evaluation import *
from utils import *
from data import load_data
from dataset import MultitaskRerankingDataset
from training_utils import *
from model import ModelMultitaskBinary
from evaluation_utils import *



def _str2bool(value):
    """Convert a command-line string into a boolean.

    ``type=bool`` applies ``bool()`` to the raw string, so every non-empty
    value (including "False" and "0") ends up as True. This converter looks
    at the text instead.

    Args:
        value: Raw command-line string, or an actual bool.

    Returns:
        The parsed boolean. The empty string maps to False, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line configuration of the evaluation run. Parsing happens at import
# time and the result is exposed as the module-level `args`.
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--cuda', type=_str2bool, default=True)

# data
parser.add_argument('--data_folder', type=str, default="../../DATASETS/RedditTIFU/data/")
# List options take space-separated values (e.g. `--scoring_methods 1a_rouge_1 1b_rouge_2`);
# `type=list` would have split a single value into its characters.
parser.add_argument('--generation_methods', type=str, nargs='+', default=[
    "1_beam_search",
    "2_diverse_beam_search",
    #"3_top_p_sampling",
    #"4_top_k_sampling",
])
parser.add_argument('--scoring_methods', type=str, nargs='+', default=[
    #"1_mean_rouge",
    "1a_rouge_1",
    "1b_rouge_2",
    "1c_rouge_l",
    #"2_bertscore",
    #"4_bartscore"
])
parser.add_argument('--scored_summaries_path', type=str, default="../reranking_data/Reddit/")
parser.add_argument('--sep_symbol', type=str, default="[SEP]")
parser.add_argument('--highlights', type=_str2bool, default=False)  # True for CNNDM, False for WikiHow / XSum / RedditTIFU
# val
parser.add_argument('--val_dataset', type=str, default="small_val")
parser.add_argument('--val_data_size', type=int, default=300)  # 300 / 11490 / 11334 / 4222
parser.add_argument('--val_size', type=int, default=300)

# base model
parser.add_argument('--model_name', type=str, default="pegasus_reddit_train_1_v2")
# in ["pegasus_unsupervised", "pegasus_cnndm_first_half_shuffled_1", "pegasus_cnndm_second_half_shuffled_1", "pegasus_cnndm",
# "bart_unsupervised", "bart_cnndm_first_half_shuffled_1", "bart_cnndm_second_half_shuffled_1", "bart_cnndm"]
parser.add_argument('--num_beams', type=int, default=15)

# model
# candidate selection
parser.add_argument('--pos_neg_construction', type=str, default="overall_sum")  # in ["overall_sum", "per_task"]
parser.add_argument('--filter_out_duplicates', type=_str2bool, default=True)
parser.add_argument('--prune_candidates', type=_str2bool, default=True)
parser.add_argument('--sampling_strat', type=str, default="bottom")  # in ["random", "bottom"]
parser.add_argument('--n_negatives', type=int, default=1)
parser.add_argument('--max_n_candidates', type=int, default=3)
# encoder
parser.add_argument('--model', type=str, default="roberta-large")
parser.add_argument('--model_type', type=str, default="roberta")  # in ["bert", "roberta", "t5", "pegasus"]
parser.add_argument('--cache_dir', type=str, default="../../hf_models/roberta-large/")
parser.add_argument('--hidden_size', type=int, default=1024)  # 768 / 1024
parser.add_argument('--non_linear_repres', type=_str2bool, default=True)
parser.add_argument('--max_length', type=int, default=448)  # 384 / 448
parser.add_argument('--max_summary_length', type=int, default=64)  # 128 / 64
parser.add_argument('--encode_begin_end', type=_str2bool, default=False)
parser.add_argument('--encode_head_tail', type=_str2bool, default=False)
parser.add_argument('--pack_text_summaries', type=_str2bool, default=False)
# shared bottom
parser.add_argument('--use_shared_bottom', type=_str2bool, default=True)
parser.add_argument('--bottom_hidden_size', type=int, default=1024)
# experts
parser.add_argument('--num_experts', type=int, default=6)
#parser.add_argument('--noisy_gating', type=bool, default=True)
parser.add_argument('--k', type=int, default=3)
parser.add_argument('--use_aux_loss', type=_str2bool, default=False)
parser.add_argument('--expert_hidden_size', type=int, default=1024)
# tower
parser.add_argument('--tower_hidden_size', type=int, default=1024)
# weights
parser.add_argument('--load_model', type=_str2bool, default=True)
parser.add_argument('--load_model_path', type=str, default="saved_models/reddit/multitask_3_tasks_ablation_10/checkpoint-600/pytorch_model.bin")

# optimization
parser.add_argument('--inference_bs', type=int, default=60)

# generation
parser.add_argument('--stemmer', type=_str2bool, default=True)
parser.add_argument('--n_show_summaries', type=int, default=0)
parser.add_argument('--clean_n', type=_str2bool, default=False)
parser.add_argument('--rouge_to_use', type=str, default="rouge_score")

# evaluation
# 0
parser.add_argument('--evaluate_metrics_correlation', type=_str2bool, default=False)
# 1
parser.add_argument('--evaluate_quant_metrics', type=_str2bool, default=True)
# 2
parser.add_argument('--evaluate_tasks_probs', type=_str2bool, default=True)
# 3
parser.add_argument('--evaluate_qualitative_samples', type=_str2bool, default=False)
parser.add_argument('--evaluate_qualitative_full_predictions', type=_str2bool, default=True)
# 4
parser.add_argument('--evaluate_reranker_on_label', type=_str2bool, default=False)
# 5
parser.add_argument('--evaluate_ranking_correlation', type=_str2bool, default=False)
# 6
parser.add_argument('--evaluate_per_length', type=_str2bool, default=False)
# 7
parser.add_argument('--evaluate_recall', type=_str2bool, default=False)
# 8
parser.add_argument('--evaluate_ttest', type=_str2bool, default=False)
# 9
parser.add_argument('--evaluate_summary_length', type=_str2bool, default=False)
# 10
parser.add_argument('--evaluate_abstractiveness', type=_str2bool, default=False)

# metrics
# 1 - ROUGE
parser.add_argument('--eval_rouge', type=_str2bool, default=True)
# 2 - BERTScore
parser.add_argument('--eval_bertscore', type=_str2bool, default=True)
# 3 - BARTScore
parser.add_argument('--eval_bartscore', type=_str2bool, default=True)
# 4 - Copying
parser.add_argument('--eval_ngram_overlap', type=_str2bool, default=False)
# 5 - Abstractiveness
parser.add_argument('--eval_new_ngrams', type=_str2bool, default=False)
# 6 - Overlap with source
parser.add_argument('--eval_rouge_text', type=_str2bool, default=False)
parser.add_argument('--check_correlation', type=_str2bool, default=False)

# export
parser.add_argument('--export', type=_str2bool, default=False)
parser.add_argument('--export_name', type=str, default="predictions/reddit_pegasus_3_tasks_ablation_5_test_100.pkl")
parser.add_argument('--n_to_export', type=int, default=100)

args = parser.parse_args()
# One task per scoring method.
args.n_tasks = len(args.scoring_methods)

print("*" * 50)
print(args)

#time.sleep(10000)



def main(args):
    """Run a trained re-ranker on a validation split and report the enabled evaluations.

    Steps: seed, pick the device, load the scored candidates, sub-sample
    ``args.val_size`` documents, build the dataset and model, run inference
    without gradients, select the predicted candidate of every document and
    call the evaluation routines switched on through the ``args.evaluate_*``
    flags. Predictions are exported when ``args.export`` is set.

    Args:
        args: Parsed command-line namespace. ``args.device`` is set here as a
            side effect.
    """
    # seed
    seed_everything(args.seed)

    # device
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    args.device = device
    print("Using device: {}".format(device))

    # tokenizer
    tokenizer = build_tokenizer(args)

    # data
    split_name = args.val_dataset
    texts, summaries, scored_summaries = load_data(split_name, args.val_data_size, args, individual_txt = args.highlights)
    print("loaded new data!", len(texts), len(summaries), len(scored_summaries), len(scored_summaries[0]),
          len(scored_summaries[0][0]), len(scored_summaries[0][1]))
    # Random sub-sample of at most args.val_size documents, applied identically to the three lists.
    perm = np.random.permutation(len(texts))[:args.val_size]
    texts = [texts[i] for i in perm]
    summaries = [summaries[i] for i in perm]
    scored_summaries = [scored_summaries[i] for i in perm]

    # 0 - metrics correlation is a standalone mode: nothing else runs afterwards.
    if args.evaluate_metrics_correlation:
        metrics_correlation(scored_summaries, args)
        return

    # dataset
    mode = "val"
    val_dataset = MultitaskRerankingDataset(mode, tokenizer, texts, scored_summaries, summaries, args)
    # Ceiling division: a trailing partial batch still counts as a batch.
    n_batches = (len(val_dataset.texts) + args.inference_bs - 1) // args.inference_bs
    print("There are {} {} batches".format(n_batches, split_name))

    # data loader (no shuffling: predictions must stay aligned with scored_summaries)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = args.inference_bs, shuffle = False)

    # model
    pretrained_model = build_model(args)
    model = ModelMultitaskBinary(pretrained_model, tokenizer, args)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nThe model has {} trainable parameters".format(n_params))
    model = model.to(device)
    if args.load_model:
        # map_location lets a checkpoint saved on GPU be loaded on the CPU fallback.
        state_dict = torch.load(args.load_model_path, map_location = device)
        model.load_state_dict(state_dict)
        print("Loaded the model weights!", args.load_model_path)
    # Inference only: switch off train-time behaviour such as dropout.
    model.eval()

    # inference
    val_texts = []
    val_labels = []
    val_preds_idx = []
    val_predictions = []
    val_overall_predictions = []
    for batch in tqdm(val_loader):
        batch_mode = batch["mode"]
        val_texts += batch["text"]
        val_labels += batch["label"]

        text_ids = batch["text_input_ids"].to(device)
        text_mask = batch["text_attn_mask"].to(device)
        text_and_summaries_ids = batch["text_and_summaries_input_ids"].to(device)
        text_and_summaries_mask = batch["text_and_summaries_attn_mask"].to(device)
        scores = batch["scores"]

        with torch.no_grad():
            # The two None arguments are the candidate-only inputs, which are not fed to the model here.
            output = model(batch_mode, text_ids, text_mask, text_and_summaries_ids, text_and_summaries_mask, None, None, scores)
            val_preds_idx += output["total_predictions_idx"]
            val_predictions.append(output["prediction_sum"].item())
            val_overall_predictions += output["overall_predictions"]
    print("# texts: {}, # summaries: {}, # preds idx: {}, predictions: {}".format(len(val_texts), len(val_labels), len(val_preds_idx), len(val_predictions)))
    print("Mean predictions: {:.4f}".format(np.mean(np.array(val_predictions))))

    # Selected candidate per document; element 0 of each entry is indexed by candidate position
    # (presumably the candidate texts - confirm against load_data). Use index 0 instead of the
    # predicted index to score the first candidate of every document.
    val_preds = [
        doc_candidates[0][pred_idx]
        for doc_candidates, pred_idx in zip(scored_summaries, val_preds_idx)
    ]

    # evaluation
    print("*"*100)
    print("\nEval:")
    # 1 - quantitative metrics
    if args.evaluate_quant_metrics:
        print("\n", ">"*20, "\nEvaluate - quantitative metrics:")
        overall_eval(val_texts, val_preds, val_labels, args)
    # 2 - probability distribution across experts for each task
    if args.evaluate_tasks_probs:
        print("\n", ">"*20, "\nEvaluate - tasks probabilities:")
        model.moe.display_tasks_probs()
    # 3 - find good qualitative examples of re-ranked summaries
    if args.evaluate_qualitative_samples:
        print("\n", ">"*20, "\nEvaluate - qualitative samples:")
        qualitative_samples(val_texts, scored_summaries, val_preds_idx, val_overall_predictions, val_labels, args)
    # 4 - see how the reranker ranks the label
    if args.evaluate_reranker_on_label:
        print("\n", ">"*20, "\nEvaluate - reranker on label:")
        reranker_on_label(val_dataset, val_labels, model, args)
    # 5 - ranking correlation between beam/reranker ranks, true ranks
    if args.evaluate_ranking_correlation:
        print("\n", ">"*20, "\nEvaluate - ranking correlation:")
        ranking_correlations(scored_summaries, val_overall_predictions, args)
    # 6 - per length evaluation
    if args.evaluate_per_length:
        print("\n", ">"*20, "\nEvaluate - per length:")
        length_evaluation(val_texts, scored_summaries, val_preds_idx, val_labels, args)
    # 7 - recall
    if args.evaluate_recall:
        print("\n", ">"*20, "\nEvaluate - recall:")
        recall_evaluation(scored_summaries, val_overall_predictions, args)
    # 8 - statistical significance
    if args.evaluate_ttest:
        print("\n", ">"*20, "\nEvaluate - T-test:")
        ttest(scored_summaries, val_overall_predictions, args)
        n_scored = len(val_overall_predictions)
        ttest_all_metrics(scored_summaries[:n_scored], val_overall_predictions, val_labels[:n_scored], args)
    # 9 - summary length
    if args.evaluate_summary_length:
        print("\n", ">"*20, "\nEvaluate - summary length:")
        summary_length(scored_summaries, val_preds_idx, val_labels, args)
    # 10 - abstractiveness
    if args.evaluate_abstractiveness:
        print("\n", ">"*20, "\nEvaluate - abstractiveness:")
        new_ngrams(val_texts, scored_summaries, val_preds_idx, val_labels, args)

    # export
    if args.export:
        export(val_texts, scored_summaries, val_overall_predictions, args)
        print("saved!")



# Script entry point: evaluate with the arguments parsed at module level.
if __name__ == "__main__":
    main(args)
