import argparse
import json
import os

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from response_selection.preprocessing import get_syndd_corpus
from response_selection.model import ResponseSelection, ResponseSelectionforMultiTask
from response_selection.dataset import SelectionDataset
from response_selection.utils import (
    PREFIX_DIR,
    UTTR_TOKEN,
    LMDICT,
    set_random_seed,
    set_logger,
    load_model,
)


parser = argparse.ArgumentParser(description="Process arguments for evaluating a response selection model")
parser.add_argument("--random-seed", type=int, default=42, help="random seed during training")
parser.add_argument("--log-path", type=str, default="result")
parser.add_argument("--corpus", default="dd", choices=["persona", "dd"])
parser.add_argument("--hardneg-smoothing",
    type=float,
    default=0,
    help="hard neg prediction에 사용할 레이블 비율. 이걸 0.2로 하로 hard neg가 5개면 하나의 hard neg가 0.2/5=0.04를 나눠갖고, positive는 0.8의 label값을 가짐.",
)

# Arguments for additional loss
# neg type classification loss (multi-task learning)
parser.add_argument("--pred-neg-type", type=str, default="False", choices=["True", "False"])
parser.add_argument("--pred-neg-type-alpha", type=float, default=0.1)
# triplet marginal loss
parser.add_argument("--triplet-margin", type=float, default=-1)
parser.add_argument("--triplet-alpha", type=float, default=0)
# ntxent loss
parser.add_argument("--ntxent-temp", type=float, default=-1)
parser.add_argument("--ntxent-alpha", type=float, default=0)

# Arguments in the testing model setup
parser.add_argument("--is-curriculum", action="store_true", help="default=False")
parser.add_argument("--is-shuffle", action="store_true", help="default=False, given=True")

parser.add_argument("--train-num-candidates", type=int, default=11, help="total number of candidates(pos + negs) used during training")
parser.add_argument("--train-num-hard-negs", type=int, default=5, help="total number of hard negatives used during training")
parser.add_argument("--lmtype", type=str, default="bert", choices=["bert", "roberta", "electra"])
parser.add_argument("--neg-type", type=str, default="gpt", choices=["random", "human", "gpt", "syn"])
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--model-path", type=str, default="logs_{}/{}_{}_batch32_candi{}_hard{}_seed{}{}{}/model")
parser.add_argument("--target-epoch", type=int, default=0)
parser.add_argument(
    "--approach",
    type=str,
    default="none",
    choices=[
        "none",
        "bm25",
        "maskandfill",
        "kwsim",
        "direct_w_ans",
        "direct_wo_ans",
        "direct_wo_ans_10p",  # direct_wo_ans by using 10% of human-written samples
        "direct_wo_ans_1p",  # direct_wo_ans by using 1% of human-written samples
        "direct_wo_ans_01p",  # direct_wo_ans by using 0.1% of human-written samples
        "direct_wo_ans_01p_reuse",  # direct_wo_ans by using 0.1% of human-written samples + reuse
        "direct_aug_10000",  # DTI by adding augmented 10,000 dataset
        "direct_1_shot",
        "direct_0_shot",
        "meta",
        "semi_hard",  # random: "semi_hard"
    ],
)

# Arguments for choosing test dataset
parser.add_argument("--test-candidate-num", type=int, default=6, help="total number of candidates(pos + negs) used during testing")
parser.add_argument("--test-num-hard-negs", type=int, default=5, help="total number of hard negatives used during testing")
parser.add_argument("--test-file-name", type=str, default="dailydialog++", choices=["dailydialog++", "dailydialog_gpt"])
parser.add_argument("--test-neg-type", type=str, default="random", choices=["random", "human", "gpt", "all"])
parser.add_argument("--test-approach", type=str, default="direct_w_ans", choices=["none", "direct_w_ans", "direct_wo_ans", "direct_0_shot", "meta"])
parser.add_argument(
    "--test-object",
    type=str,
    default="random_negative_responses",
    choices=[
        "random_negative_responses",
        "adversarial_negative_responses",
        "gpt3_negative_responses",
    ],
)


def main(args):
    """
    Testing pipeline

    Loads the tokenizer and the response selection checkpoint, scores every
    candidate of every test context and writes one JSON object per context
    (keys ``"pred"`` and ``"uncertainty"``) to ``args.output_fname``, one
    object per line.

    :param args: parsed command-line arguments. In addition to the parser
        options, ``args.model_path`` must already be formatted and
        ``args.output_fname`` must be set (both are done in the ``__main__``
        section of this file).
    :raises ValueError: if a test sample does not hold
        ``2 * args.test_candidate_num + 1`` features.
    :raises RuntimeError: if the number of scores produced for a context
        differs from ``args.test_candidate_num``.
    """
    # set random seed
    set_random_seed(args.random_seed)

    # set logger
    logger = set_logger("response-selection-model-testing")

    # record and report params
    for k, v in vars(args).items():
        logger.info(f"{k}: {v}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_name = LMDICT[args.lmtype]

    # The CLI delivers --pred-neg-type as the string "True"/"False" and the
    # entry point converts it to a bool.  The string "False" is truthy, so it
    # is handled explicitly here in case main() receives unconverted args.
    neg_type_flag = args.pred_neg_type
    use_multitask = bool(neg_type_flag) and neg_type_flag != "False"

    # set tokenizer (extended with the utterance separator token)
    logger.info(f"[+] Load Tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    special_tokens_dict = {"additional_special_tokens": [UTTR_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)

    # load pretrained model weight; the embedding matrix has to match the
    # tokenizer after the special token was added
    logger.info(f"[+] Load Pretrained Model: {model_name}")
    pretrained_model = AutoModel.from_pretrained(model_name)
    pretrained_model.resize_token_embeddings(len(tokenizer))

    # load response selection model
    if use_multitask:
        logger.info("[+] Load Model: ResponseSelection for Multi-task learning")
        model = ResponseSelectionforMultiTask(pretrained_model)
    else:
        logger.info("[+] Load Model: ResponseSelection")
        model = ResponseSelection(pretrained_model)

    model = load_model(
        model=model,
        model_path=args.model_path,
        epoch=args.target_epoch,
        len_tokenizer=len(tokenizer),
    )
    model.to(device)
    model.eval()

    # Predictions are aggregated over this list (mean / variance); only one
    # model is loaded here, so the reported variance is always 0.
    model_list = [model]

    # load dataset
    raw_test_dataset = get_syndd_corpus(args.test_file_name, "test", args.test_approach)

    # Both file names keep a trailing "{}" placeholder that is left unformatted
    # here and handed to SelectionDataset as-is.
    text_fname = (
        PREFIX_DIR
        + "data/selection_{}/text_cand{}".format(args.lmtype, args.test_candidate_num)
        + "_{}".format(args.test_neg_type)
        + "_{}".format(args.test_approach)
        + "_{}.pck"
    )
    tensor_fname = (
        PREFIX_DIR
        + "data/selection_{}/tensor_cand{}".format(args.lmtype, args.test_candidate_num)
        + "_{}".format(args.test_neg_type)
        + "_{}".format(args.test_approach)
        + "_{}.pck"
    )

    test_dataset = SelectionDataset(
        raw_dataset=raw_test_dataset,
        tokenizer=tokenizer,
        setname="test",
        target_object=args.test_object,
        max_seq_len=128,
        num_candidates=args.test_candidate_num,
        num_hard_negs=args.test_num_hard_negs,
        uttr_token=UTTR_TOKEN,
        txt_save_fname=text_fname,
        tensor_save_fname=tensor_fname,
    )

    num_candidates = args.test_candidate_num
    expected_num_features = 2 * num_candidates + 1
    total_item_list = []

    for idx in tqdm(range(len(test_dataset))):
        sample = [el[idx] for el in test_dataset.feature]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(sample) != expected_num_features:
            raise ValueError(
                f"sample {idx} has {len(sample)} features, expected "
                f"{expected_num_features} for --test-candidate-num={num_candidates}"
            )

        # Feature layout: the first `num_candidates` entries are stacked as
        # input ids, the next `num_candidates` as the matching masks.  The
        # final entry is not used here (presumably the label -- TODO confirm).
        ids = torch.stack(sample[:num_candidates]).to(device)
        mask = torch.stack(sample[num_candidates : 2 * num_candidates]).to(device)

        prediction_list = []
        with torch.no_grad():
            for member in model_list:
                output = member(ids, mask)
                # The multi-task model's output is indexed with [0] to obtain
                # the selection scores.
                scores = output[0] if use_multitask else output
                # Convert element-wise to Python floats so the statistics
                # below are computed in double precision.
                prediction_list.append([float(el) for el in scores.cpu().numpy()])

        prediction_list = np.array(prediction_list)
        pred_list_for_current_context = [float(el) for el in np.mean(prediction_list, 0)]
        uncertainty_list_for_current_context = [float(el) for el in np.var(prediction_list, 0)]

        if not (
            len(pred_list_for_current_context)
            == len(uncertainty_list_for_current_context)
            == num_candidates
        ):
            raise RuntimeError(
                f"sample {idx}: got {len(pred_list_for_current_context)} scores, "
                f"expected {num_candidates}"
            )

        total_item_list.append(
            {
                "pred": pred_list_for_current_context,
                "uncertainty": uncertainty_list_for_current_context,
            }
        )

    # JSON-lines output: one object per test context
    with open(args.output_fname, "w") as f:
        for item in total_item_list:
            json.dump(item, f)
            f.write("\n")


if __name__ == "__main__":
    args = parser.parse_args()

    # --pred-neg-type is parsed as the string "True"/"False"; main() needs a bool.
    args.pred_neg_type = args.pred_neg_type == "True"

    curriculum_tag = "_curriculum" if args.is_curriculum else ""
    shuffle_tag = "" if args.is_shuffle else "_nshuffle"

    # default path = "logs_{}/{}_{}_batch32_candi{}_hard{}_seed{}{}{}/model"
    #
    # NOTE(review): this script used to build a seed suffix
    # ("<seed>_hardsmooth{..}_margin{..}alpha{..}_ntxenttemp{..}alpha{..}_predneg{..}")
    # but never inserted it anywhere; the seed slot has always been filled with
    # the plain random seed.  The dead computation was removed without changing
    # the resulting path.  If checkpoints trained with the extra losses are
    # stored under such a suffix, the seed slot below needs it -- verify
    # against the training script.
    args.model_path = args.model_path.format(
        args.lmtype,
        args.neg_type,
        args.approach,
        args.train_num_candidates,
        args.train_num_hard_negs,
        args.random_seed,
        curriculum_tag,
        shuffle_tag,
    )

    args.exp_name = "{}_{}_hard{}{}{}-{}_{}-test_candi{}_test".format(
        args.neg_type,
        args.approach,
        args.train_num_hard_negs,
        curriculum_tag,
        shuffle_tag,
        args.test_neg_type,
        args.test_approach,
        args.test_candidate_num,
    )

    # Optional suffixes identifying the additional training losses.
    if args.hardneg_smoothing > 0:
        args.exp_name += "-hardsmooth{}".format(args.hardneg_smoothing)
    if args.triplet_alpha > 0:
        args.exp_name += "-margin{}alpha{}".format(args.triplet_margin, args.triplet_alpha)
    if args.ntxent_alpha > 0:
        # NOTE(review): "ntexenttemp" (sic) is part of the output file name and
        # is kept unchanged so existing result files keep their names.
        args.exp_name += "-ntexenttemp{}alpha{}".format(args.ntxent_temp, args.ntxent_alpha)
    if args.pred_neg_type:
        args.exp_name += "-predneg{}".format(args.pred_neg_type_alpha)

    args.log_path = os.path.join(args.log_path, args.corpus)
    args.output_fname = os.path.join(args.log_path, args.exp_name) + ".json"

    # Never overwrite existing results.  An explicit raise is used because an
    # `assert` would be stripped under `python -O`.
    if os.path.exists(args.output_fname):
        raise FileExistsError(f"output file already exists: {args.output_fname}")

    os.makedirs(os.path.dirname(args.output_fname), exist_ok=True)

    print("\n", args.output_fname, "\n")

    main(args)
