import argparse
import os
import sys

import torch
import numpy as np
from torch import nn
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer,
)
from response_selection.preprocessing import get_syndd_corpus
from response_selection.model import ResponseSelection
from response_selection.dataset import SelectionDataset
from response_selection.utils import (
    PREFIX_DIR,
    UTTR_TOKEN,
    LMDICT,
    set_logger,
    set_random_seed,
    dump_config,
    save_model,
    write2tensorboard,
)

def _str2bool(value):
    """
    Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    :param value: raw command-line token (or an already parsed bool)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if the token is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Process arguments for training a response selection model")

# curriculum learning
parser.add_argument("--is-curriculum", action="store_true", help="default=False, given=True")
# "--is_shuffle" is kept as an alias so that existing launch scripts keep working;
# the destination is ``args.is_shuffle`` for both spellings.
parser.add_argument("--is-shuffle", "--is_shuffle", action="store_true", help="default=False, given=True")

parser.add_argument("--random-seed", type=int, default=42)
parser.add_argument("--log-path", type=str, default="logs")
parser.add_argument(
    "--dataset-path",
    type=str,
    default="dailydialog_gpt",
    choices=["dailydialog_gpt", "dailydialog_syn", "dailydialog++"],
)
# store_true (not store_false): the flag must ENABLE the multi-GPU setup, as the help text says.
parser.add_argument(
    "--is-distributed",
    action="store_true",
    help="if the argument is given, set up a distributed training",
)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--learning-rate", type=float, default=2e-5)
parser.add_argument("--epoch", type=int, default=3)
parser.add_argument("--retrieval-candidate-num", type=int, default=11, help="Number of candidates including golden response")
parser.add_argument("--hard-neg-num", type=int, default=5, help="Number of adversarial negatives within retrieval candidates")
parser.add_argument("--grad-max-norm", type=float, default=1.0, help="gradient max norm value with Gradient Clipping")
# Accepts "--is-softlabel", "--is-softlabel true" and "--is-softlabel false".
parser.add_argument(
    "--is-softlabel",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
    help="default=False; give the flag alone or followed by true/false",
)

# Arguments for experiment setup
parser.add_argument("--lmtype", type=str, default="bert", choices=["bert", "roberta", "electra"])
parser.add_argument("--neg-type", type=str, default="gpt", choices=["random", "human", "gpt", "syn"])
parser.add_argument(
    "--target-object",
    type=str,
    default="gpt3_negative_responses",
    choices=[
        "random_sampled",  # random : none
        "bm25_sampled",  # syn : bm25
        "adv_gen_neg_responses_t1",  # syn : maskandfill, kwsim
        "adversarial_negative_responses",  # human : none
        "gpt3_negative_responses",  # gpt : direct_w_ans, direct_wo_ans, direct_0_shot, meta
        "semi_hard",  # random: "semi_hard"
    ],
)
parser.add_argument(
    "--approach",
    type=str,
    default="none",
    choices=[
        "none",
        "bm25",
        "maskandfill",
        "kwsim",
        "direct_w_ans",
        "direct_wo_ans",
        "direct_wo_ans_10p",  # DTI by using 10% of human-written samples
        "direct_wo_ans_1p",  # DTI by using 1% of human-written samples
        "direct_wo_ans_01p",  # DTI by using 0.1% of human-written samples
        "direct_wo_ans_01p_reuse",  # DTI by using 0.1% of human-written samples + reuse generated
        "direct_aug_10000",  # DTI by adding augmented 10,000 dataset
        "direct_1_shot",
        "direct_0_shot",
        "meta",
        "semi_hard",
    ],
)


def evaluation(
    device,
    args: argparse.Namespace,
    writer: SummaryWriter,
    global_step: int,
    model: nn.Module,
    crossentropy: CrossEntropyLoss,
    validloader: DataLoader,
):
    """
    Compute the mean validation loss of the response selection model and log it.

    The model is switched to eval mode and is left in that mode; the caller has
    to call ``model.train()`` again before it resumes training.

    :param device: device the batch tensors are moved to
    :param args: training arguments (only ``retrieval_candidate_num`` is read)
    :param writer: SummaryWriter for tensorboard
    :param global_step: current global step, forwarded to the tensorboard writer
    :param model: current model
    :param crossentropy: cross entropy loss function
    :param validloader: dataloader for validation dataset
    :return: mean of the per-batch losses, or ``None`` if the loader yielded no batch
    """
    num_candidates = args.retrieval_candidate_num

    model.eval()
    loss_list = []

    with torch.no_grad():
        for batch in tqdm(validloader):
            # Batch layout: `num_candidates` id tensors, then `num_candidates`
            # mask tensors, then the label tensor.
            ids_list = batch[:num_candidates]
            mask_list = batch[num_candidates : 2 * num_candidates]
            labels = batch[2 * num_candidates]

            batch_size = labels.shape[0]
            labels = labels.to(device)
            # Flatten (batch, candidates * seq_len) into one row per candidate.
            # The sequence length is inferred (-1) instead of being hard-coded.
            ids_list = torch.cat(ids_list, 1).reshape(batch_size * num_candidates, -1).to(device)
            mask_list = torch.cat(mask_list, 1).reshape(batch_size * num_candidates, -1).to(device)

            output = model(ids_list, mask_list)
            # One row of candidate scores per example.
            output = output.reshape(batch_size, -1)

            loss = crossentropy(output, labels)
            loss_list.append(loss)

    if not loss_list:
        # E.g. a validation set smaller than one batch combined with drop_last=True.
        # Averaging would divide by zero, so skip logging instead of crashing.
        import warnings

        warnings.warn("validation dataloader yielded no batch; validation loss is not logged")
        return None

    final_loss = sum(loss_list) / len(loss_list)
    write2tensorboard(writer, {"loss": final_loss}, "valid", global_step)
    return final_loss


def run(args: argparse.Namespace):
    """
    Training pipeline

    Builds the tokenizer, the model and the train/dev datasets, then trains for
    ``args.epoch`` epochs. After every epoch the dev loss is logged and the
    model weights are saved.

    The dev set is always the "dailydialog++" valid split with
    "adversarial_negative_responses" as negatives, independent of the training
    configuration.

    :param args: training arguments; besides the command-line options this
        function reads ``args.board_path`` and ``args.model_path``, which are
        not parser options and must be set by the caller
    """
    # Values shared by dataset construction and the batch reshaping below.
    max_seq_len = 128
    dev_hard_neg_num = 5
    num_candidates = args.retrieval_candidate_num

    # set random seed
    set_random_seed(args.random_seed)

    # set logger
    logger = set_logger("response-selection-model-training")

    # save model config
    dump_config(args)

    # record and report params
    for k, v in vars(args).items():
        logger.info(f"{k}: {v}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_name = LMDICT[args.lmtype]

    # set tokenizer
    logger.info(f"[+] Load Tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    special_tokens_dict = {"additional_special_tokens": [UTTR_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)

    # load pretrained model weight
    logger.info(f"[+] Load Pretrained Model: {model_name}")
    pretrained_model = AutoModel.from_pretrained(model_name)
    # the embedding matrix has to grow by the special token added above
    pretrained_model.resize_token_embeddings(len(tokenizer))

    # load response selection model
    logger.info("[+] Load Model: ResponseSelection")
    model = ResponseSelection(pretrained_model)

    # distributed training setup (implemented with torch.nn.DataParallel)
    if args.is_distributed:
        logger.info("[+] Distributed Training Setup")
        model = torch.nn.DataParallel(model)
    model.to(device)

    # load dataset
    logger.info(f"[+] Load Dataset: {args.dataset_path} with {args.approach} approach")
    raw_dd_train = get_syndd_corpus(args.dataset_path, "train", args.approach)
    raw_dd_dev = get_syndd_corpus("dailydialog++", "valid", args.approach)

    # Cache file names. The train names end in a literal "{}" placeholder that
    # is passed on unformatted to SelectionDataset.
    cache_dir = PREFIX_DIR + "data/selection_{}/".format(args.lmtype)
    train_stem = "cand{}_hard{}".format(num_candidates, args.hard_neg_num)
    train_suffix = (
        "_{}".format(args.neg_type)
        + "_{}".format(args.approach)
        + ("_curriculum" if args.is_curriculum else "")
        + "_{}.pck"
    )
    dev_stem = "cand{}_hard{}".format(num_candidates, dev_hard_neg_num)
    dev_suffix = "_human_none_dev.pck"

    text_train_fname = cache_dir + "text_" + train_stem + train_suffix
    pickle_train_fname = cache_dir + "tensor_" + train_stem + train_suffix
    text_dev_fname = cache_dir + "text_" + dev_stem + dev_suffix
    pickle_dev_fname = cache_dir + "tensor_" + dev_stem + dev_suffix

    train_dataset = SelectionDataset(
        raw_dataset=raw_dd_train,
        tokenizer=tokenizer,
        setname="train",
        target_object=args.target_object,
        max_seq_len=max_seq_len,
        num_candidates=num_candidates,
        num_hard_negs=args.hard_neg_num,
        is_curriculum=args.is_curriculum,
        uttr_token=UTTR_TOKEN,
        txt_save_fname=text_train_fname,
        tensor_save_fname=pickle_train_fname,
    )

    dev_dataset = SelectionDataset(
        raw_dataset=raw_dd_dev,
        tokenizer=tokenizer,
        setname="dev",
        target_object="adversarial_negative_responses",
        max_seq_len=max_seq_len,
        num_candidates=num_candidates,
        num_hard_negs=dev_hard_neg_num,
        is_curriculum=False,
        uttr_token=UTTR_TOKEN,
        txt_save_fname=text_dev_fname,
        tensor_save_fname=pickle_dev_fname,
    )

    # dedicated, seeded generator for the train loader
    g = torch.Generator()
    g.manual_seed(args.random_seed)

    trainloader = DataLoader(
        train_dataset,
        shuffle=args.is_shuffle,
        batch_size=args.batch_size,
        drop_last=True,
        generator=g,
    )
    validloader = DataLoader(
        dev_dataset, batch_size=args.batch_size, drop_last=True
    )

    logger.info("[+] Dataset and Dataloader Statistics...")
    logger.info(f"[+] Train examples: {len(train_dataset)} / steps: {len(trainloader)}")
    logger.info(f"[+] Dev examples: {len(dev_dataset)} / steps: {len(validloader)}")

    logger.info("[+] Start Training")

    crossentropy = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=args.learning_rate)

    writer = SummaryWriter(args.board_path)

    global_step = 0
    optimizer.zero_grad()

    # try/finally so the tensorboard writer is closed even if training raises;
    # otherwise the event file is never closed explicitly.
    try:
        for epoch in range(args.epoch):
            model.train()
            for batch in tqdm(trainloader):
                # Batch layout: `num_candidates` id tensors, then
                # `num_candidates` mask tensors, then the label tensor.
                ids_list = batch[:num_candidates]
                mask_list = batch[num_candidates : 2 * num_candidates]
                labels = batch[2 * num_candidates]

                batch_size = labels.shape[0]
                # one row per (example, candidate) pair
                ids_list = torch.cat(ids_list, 1).reshape(batch_size * num_candidates, max_seq_len).to(device)
                mask_list = torch.cat(mask_list, 1).reshape(batch_size * num_candidates, max_seq_len).to(device)
                labels = labels.to(device)

                output = model(ids_list, mask_list)
                # one row of candidate scores per example
                output = output.reshape(batch_size, -1)

                loss = crossentropy(output, labels)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_max_norm)
                optimizer.step()
                write2tensorboard(writer, {"loss": loss}, "train", global_step)
                global_step += 1

                optimizer.zero_grad()

            # evaluation with dev dataset
            logger.info("[+] Start Evaluation")
            evaluation(
                device=device,
                args=args,
                writer=writer,
                global_step=global_step,
                model=model,
                crossentropy=crossentropy,
                validloader=validloader,
            )

            # save the model weight
            save_model(model, epoch, args.model_path, args.is_distributed)
    finally:
        writer.close()


def main():
    """
    Parse the command-line arguments, derive the experiment paths and run training.

    The experiment name encodes the negative type, approach, batch size,
    candidate count, hard-negative count and seed, plus optional
    "_curriculum" / "_nshuffle" markers. The model and tensorboard directories
    are created below ``<log_path>/<exp_name>`` before training starts.
    """
    args = parser.parse_args()

    # pieces of the experiment name, joined with "_" at the end
    name_parts = [
        args.neg_type,
        "{}".format(args.approach),
        "batch{}".format(args.batch_size),
        "candi{}".format(args.retrieval_candidate_num),
        "hard{}".format(args.hard_neg_num),
        "seed{}".format(args.random_seed),
    ]
    if args.is_curriculum:
        name_parts.append("curriculum")
    if not args.is_shuffle:
        name_parts.append("nshuffle")

    args.log_path = "".join((PREFIX_DIR, args.log_path, "_", args.lmtype))
    args.exp_name = "_".join(name_parts)

    args.exp_path = os.path.join(args.log_path, args.exp_name)
    args.model_path = os.path.join(args.exp_path, "model")
    args.board_path = os.path.join(args.exp_path, "board")

    # make sure both output directories exist before training writes to them
    for out_dir in (args.model_path, args.board_path):
        os.makedirs(out_dir, exist_ok=True)

    run(args)

# Script entry point: main() returns None, so the process exits with status 0
# unless an exception propagates.
if __name__ == "__main__":
    raise SystemExit(main())
