import copy
import sys
from os import path

import wandb
import yaml

sys.path.append(
    path.dirname(path.abspath(path.dirname(path.abspath(path.dirname(__file__)))))
)

import os.path as op
import time, json
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp

from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from utils.logger import setup_loguru_logger
from utils.tsv_file_ops import (tsv_writer, concat_tsv_files,
                                delete_tsv_files, reorder_tsv_keys)
from utils.misc import (set_seed)
from utils.caption_evaluate import (evaluate_on_coco_caption,
                                    evaluate_on_nocaps, ScstRewardCriterion)
from utils.distributed_processing import (
    get_world_size, get_rank, is_main_process, synchronize, ensure_init_process_group, all_reduce_tensor
)

from utils.wandb_setup import wandb_setup
from utils.param_utils import compute_param_norm

from modeling.modeling_bert import ImageBertForSequenceClassification

from pytorch_transformers import BertTokenizer, BertConfig
from pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule

from args.args_vqa import get_arguments, save_arg
from task.vqa.dataloader import make_data_loader
from task.vqa.IO import save_checkpoint, get_predict_file, get_evaluate_file, save_best, save_predict, \
    concat_cache_files
from task.vqa.factory import model_factory

def compute_MLM_score_with_logits(logits, labels):
    """Compare the arg-max prediction over the last dimension with ``labels``.

    Args:
        logits: tensor whose last dimension holds the candidate scores.
        labels: tensor of target indices, broadcastable against the arg-max
            of ``logits``.

    Returns:
        A boolean tensor that is ``True`` where the predicted index equals
        the label.
    """
    _, predicted = torch.max(logits, -1)  # second element is the arg-max index
    return predicted.data == labels

def compute_CLS_score_with_logtis(logits, labels):
    """Keep only the label score of the class predicted for each row.

    The arg-max of ``logits`` along dimension 1 is turned into a one-hot mask
    with ``labels.shape[-1]`` classes and multiplied element-wise with
    ``labels``.

    Args:
        logits: class scores; dimension 1 indexes the classes.
        labels: per-class target scores (assumed to be shaped like
            ``logits`` -- TODO confirm against the dataset).

    Returns:
        A tensor shaped like ``labels`` that is zero everywhere except at the
        predicted class, where it carries that class's label score.
    """
    predicted_class = torch.argmax(logits, 1)
    class_count = labels.shape[-1]
    selection_mask = F.one_hot(predicted_class, num_classes=class_count)
    return selection_mask * labels

def train(args, train_dataloader, val_dataloader, model, tokenizer) -> str:
    """Run the training loop and return the last checkpoint directory.

    The model is wrapped in ``DistributedDataParallel`` when
    ``args.distributed`` is set, optimised with AdamW under the scheduler
    selected by ``args.scheduler``, and trained with gradient accumulation and
    (optionally, ``args.fp16``) automatic mixed precision. Checkpoints are
    written every ``args.save_steps`` optimizer steps and at the final step;
    when ``args.evaluate_during_training`` is set each checkpoint is also
    scored on ``val_dataloader`` and the best-scoring directory is remembered.

    Relies on the module-level ``logger`` that ``main()`` installs.

    Args:
        args: run configuration. ``args.num_train_epochs`` is overwritten when
            ``args.max_steps > 0``.
        train_dataloader: yields ``(img_keys, batch)`` pairs where ``batch``
            is a mapping of tensors passed to the model as keyword arguments.
        val_dataloader: loader handed to ``evaluate`` during training; only
            used when ``args.evaluate_during_training`` is set.
        model: the model to train (already on ``args.device``).
        tokenizer: forwarded to ``save_checkpoint``.

    Returns:
        Whatever the most recent ``save_checkpoint`` call returned, or ``""``
        if no checkpoint was written.

    Raises:
        ValueError: if ``args.scheduler`` is not one of ``"constant"``,
            ``"linear"`` or ``"plateau_linear"``.
    """
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # -------- calc iterations -------- #
    iters_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        # Round the epoch count up; the loop below stops exactly at t_total.
        args.num_train_epochs = args.max_steps // iters_per_epoch + 1
    else:
        t_total = iters_per_epoch * args.num_train_epochs

    # -------- optimizer -------- #
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    grouped_parameters = [
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    warmup_steps = args.warmup_step_ratio * args.num_train_epochs * iters_per_epoch
    plateau_steps = args.plateau_step_ratio * args.num_train_epochs * iters_per_epoch

    # -------- scheduler -------- #
    if args.scheduler == "constant":
        scheduler = WarmupConstantSchedule(
            optimizer,
            warmup_steps=warmup_steps
        )
        logger.info(f"Constant Scheduler ({t_total}) with Warm Up ({warmup_steps}) Steps")

    elif args.scheduler == "linear":
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total
        )
        logger.info(f"Warm Up Linear Scheduler ({t_total}) with Warm Up ({warmup_steps}) Steps")

    elif args.scheduler == "plateau_linear":
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            plateau_steps=plateau_steps,
            t_total=t_total
        )
        logger.info(f"Warm Up Linear Scheduler({t_total})" + \
                    f" with Warm up ({warmup_steps}) and Plateau ({plateau_steps}) Steps")
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    # -------- logging -------- #
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = " + f"{args.num_train_epochs}")
    logger.info("  Batch size per GPU = " + f"{args.per_gpu_train_batch_size}")
    logger.info("  Total train batch size (w. parallel, & accumulation) = " + \
                f"{args.per_gpu_train_batch_size * get_world_size() * args.gradient_accumulation_steps}")
    logger.info("  Gradient Accumulation steps = "+ f"{args.gradient_accumulation_steps}")
    logger.info("  Total optimization steps = " + f"{t_total}")

    # -------- FP16 -------- #
    scaler = GradScaler(init_scale=2048, growth_interval=1000, enabled=args.fp16)

    # -------- initialize the Misc.-------- #
    global_step, global_loss, global_acc = 0, 0.0, 0.0
    model.zero_grad()
    checkpoint_dir = ""
    best_score = 0.0
    best_ckpt_dir = ""
    reached_max_steps = False

    # -------- training -------- #
    for epoch in range(int(args.num_train_epochs)):
        for step, (img_keys, batch) in enumerate(train_dataloader):

            # evaluate() switches the model to eval mode, so re-arm every step.
            model.train()
            for key in batch.keys():
                batch[key] = batch[key].to(args.device)

            with autocast(enabled=args.fp16):
                outputs = model(**batch)

            loss, logits = outputs[:2]

            # -------- calc accuracy -------- #
            if args.MLM_train:
                masked_ids = batch['masked_ids']
                masked_ids = masked_ids[masked_ids != 0]
                batch_score = compute_MLM_score_with_logits(logits, masked_ids)
                batch_count = torch.sum(batch['masked_pos'])
            else:
                batch_score = compute_CLS_score_with_logtis(logits, batch['labels'])
                batch_count = torch.tensor(batch['labels'].shape[0], device=args.device)
            # Both reductions run on every rank, numerator first.
            batch_acc = all_reduce_tensor(torch.sum(batch_score), op="sum") \
                        / all_reduce_tensor(batch_count, op="sum")

            # -------- grad accumulation -------- #
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                batch_acc = batch_acc / args.gradient_accumulation_steps

            scaler.scale(loss).backward()
            global_loss += loss.item()
            global_acc += batch_acc

            if (step + 1) % args.gradient_accumulation_steps != 0:
                # Keep accumulating gradients until a full group is collected.
                continue

            # -------- optimizer step -------- #
            global_step += 1
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()
            model.zero_grad()

            # -------- logging -------- #
            if global_step % args.logging_steps == 0:
                # Only needed for the log record, so skip it on silent steps.
                param_norm = compute_param_norm(model.parameters())
                logger.info("Epoch: {}, global_step: {}, lr: {:.6f}, loss: {:.4f} ({:.4f}), " \
                            "score: {:.4f} ({:.4f})".format(epoch,
                                                            global_step,
                                                            optimizer.param_groups[0]["lr"],
                                                            loss*args.gradient_accumulation_steps,
                                                            global_loss / global_step,
                                                            batch_acc*args.gradient_accumulation_steps,
                                                            global_acc / global_step)
                            )
                if is_main_process():
                    wandb.log({
                        "Epoch": epoch,
                        "step": step,
                        "global_step": global_step,
                        "lr": optimizer.param_groups[0]["lr"],
                        "loss": loss*args.gradient_accumulation_steps,
                        "global_loss": global_loss / global_step,
                        "accuracy": batch_acc*args.gradient_accumulation_steps,
                        "global_accuracy": global_acc / global_step,
                        "grad_norm": grad_norm,
                        "param_norm": param_norm
                    })

            # -------- SAVE and EVAL -------- #
            if (args.save_steps > 0 and global_step % args.save_steps == 0) or \
                    global_step == t_total:
                checkpoint_dir = save_checkpoint(logger, model, tokenizer, args, epoch, global_step)

                # evaluation
                if args.evaluate_during_training:
                    logger.info(f"Perform evaluation at step: {global_step}")
                    score = evaluate(args, val_dataloader, model)
                    logger.info("Epoch {} Eval score: {:.4f}".format(epoch, score))

                    # Wandb log for the eval phase.
                    if is_main_process():
                        wandb.log({
                            "Epoch": epoch,
                            "step": step,
                            "global_step": global_step,
                            "eval_score":  score
                        })

                    # Comparing the scores
                    if score > best_score:
                        best_score = score
                        best_ckpt_dir = checkpoint_dir
                        logger.info("Update best model.")

                    synchronize()

            # Honour max_steps: the epoch count was rounded up above, so
            # without this check training would run past the requested budget.
            if args.max_steps > 0 and global_step >= t_total:
                reached_max_steps = True
                break

        if reached_max_steps:
            break

    # -------- copy best model to best folder -------- #
    # NOTE(review): best_ckpt_dir stays "" when no evaluation ran or no score
    # exceeded 0.0 -- confirm save_best copes with an empty path.
    if is_main_process():
        save_best(logger, args.output_dir, best_ckpt_dir)

    synchronize()
    return checkpoint_dir


def evaluate(args, val_dataloader, model):
    """Score ``model`` on ``val_dataloader`` and return the result.

    Delegates to ``test`` without a prediction file, so the returned value is
    the score ``test`` computes, then calls ``synchronize()`` before handing
    the score back.
    """
    logger.info("***** Start Evaluation *****")
    eval_score = test(args, val_dataloader, model)
    synchronize()
    return eval_score


def test(args, test_dataloader, model, predict_file=None):
    """Run inference over ``test_dataloader`` in one of two modes.

    * ``predict_file is None`` (scoring mode): the per-batch scores from
      ``compute_CLS_score_with_logtis`` are summed across processes and the
      total divided by ``len(test_dataloader.dataset)`` is returned.
    * otherwise (prediction mode): the arg-max answer of every question is
      collected as ``{"question_id", "answer"}`` records and written with
      ``save_predict``. With more than one process each rank writes its own
      cache file and the main process merges them into ``predict_file``;
      nothing is returned.

    Args:
        args: run configuration (reads ``device``, ``fp16``, ``num_tags``).
        test_dataloader: yields ``(q_ids, batch)`` pairs; ``batch`` is a
            mapping of tensors passed to the model as keyword arguments.
        model: the model to evaluate; it is switched to eval mode.
        predict_file: destination of the predictions, or ``None`` to score.
    """
    scoring_mode = predict_file is None

    model.eval()
    elapsed = 0
    score_total = torch.tensor(0.0).to(args.device)
    predictions = []

    with torch.no_grad():
        for (q_ids, batch) in tqdm(test_dataloader, total=len(test_dataloader)):

            for name in batch.keys():
                batch[name] = batch[name].to(args.device)

            if args.num_tags is not None:
                # Zero the attention mask from position 20 + num_tags up to 50
                # (presumably the unused object-tag slots -- TODO confirm).
                first_masked = 20 + args.num_tags
                batch["attention_mask"][..., first_masked:50].fill_(0)

            started = time.time()
            with autocast(enabled=args.fp16):
                outputs = model(**batch)

            synchronize()
            elapsed += time.time() - started

            if scoring_mode:
                _, logits = outputs[:2]
                per_class = compute_CLS_score_with_logtis(logits, batch['labels'])
                score_total += all_reduce_tensor(torch.sum(per_class), op="sum")
                continue

            logits = outputs[0]
            _, answer_ids = logits.max(1)
            predictions.extend(
                {
                    "question_id": int(question_id),
                    "answer": test_dataloader.dataset.label2ans[answer_id],
                }
                for question_id, answer_id in zip(q_ids, answer_ids.tolist())
            )

    logger.info("Inference model computing time: {} seconds".format(elapsed))

    synchronize()
    if scoring_mode:
        return score_total / len(test_dataloader.dataset)

    world_size = get_world_size()
    if world_size == 1:
        save_predict(predictions, predict_file)
        logger.info("Inference file saved")
        return

    # Multi-process: every rank dumps its share, then rank 0 merges them.
    stem, extension = op.splitext(predict_file)
    save_predict(predictions, stem + f'_{get_rank()}_{world_size}' + extension)
    synchronize()

    if is_main_process():
        shard_files = [
            stem + '_{}_{}'.format(rank, world_size) + extension
            for rank in range(world_size)
        ]
        concat_cache_files(shard_files, predict_file)

    synchronize()
    return

def _make_vqa_loader(args, yaml_file, tokenizer, *, is_train):
    """Build a data loader for ``yaml_file`` with the run-wide settings in ``args``."""
    return make_data_loader(args=args,
                            yaml_file=yaml_file,
                            tokenizer=tokenizer,
                            logger=logger,
                            is_distributed=args.distributed,
                            is_train=is_train)


def main():
    """Entry point: parse arguments, build the model and train and/or evaluate.

    Steps, in order:
      1. parse arguments and initialise the process group / device;
      2. install the module-level ``logger`` used by ``train``/``test``;
      3. derive the sequence-length settings and log them;
      4. when training, set up wandb and persist the resolved config;
      5. build the model and tokenizer through ``model_factory``;
      6. run training (optionally followed by a test pass), or a stand-alone
         test / evaluation pass.
    """
    # Setup global logger
    global logger

    # Get arguments
    args = get_arguments()

    # Setup CUDA, GPU & distributed training
    local_rank = ensure_init_process_group()
    args.local_rank = local_rank
    args.num_gpus = get_world_size()
    args.distributed = args.num_gpus > 1
    args.device = torch.device('cuda')
    synchronize()

    # Setup Logger
    logger = setup_loguru_logger("vlpretrain", args.output_dir, args.local_rank)
    logger.warning(f"Device: {args.device}, n_gpu: {args.num_gpus}")
    set_seed(args.seed, args.num_gpus)

    # Logging the sequence length information
    logger.info("VQAv2")
    if args.add_prefix is True:
        args.max_seq_length = args.max_seq_a_length + args.num_prefix
        logger.info("Model with [prefix] token")
        logger.warning(f'max_seq_length ({args.max_seq_length}) = '
                       f'question length ({args.max_seq_a_length}) + '
                       f'[prefix] token ({args.num_prefix})')
    else:
        args.max_seq_b_length = args.max_seq_length - args.max_seq_a_length
        logger.info("Model with object tags")
        logger.warning(f'max_seq_length ({args.max_seq_length}) = '
                       f'question length ({args.max_seq_a_length}) + '
                       f'object tags ({args.max_seq_b_length})')

    # WandB initialize; wandb_setup also decides the final output directory.
    if args.do_train:
        args.output_dir = wandb_setup(vars(args))

        if is_main_process():
            new_cfg_file = save_arg(args=vars(args), output_dir=args.output_dir)
            logger.info(f"new config saved to {new_cfg_file}")

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, ImageBertForSequenceClassification, BertTokenizer
    model, tokenizer, checkpoint = model_factory(logger, args, config_class, model_class, tokenizer_class)
    model.to(args.device)

    # Logging the parameters
    if is_main_process():
        print_args = copy.deepcopy(vars(args))
        # torch.device is replaced by a plain string before the YAML dump.
        print_args["device"] = "cuda" if print_args["device"].type == "cuda" else "cpu"
        logger.info(f"Training/evaluation parameters : \n{yaml.dump(print_args)}")

    if args.do_train:
        train_dataloader = _make_vqa_loader(args, args.train_yaml, tokenizer, is_train=True)
        val_dataloader = None
        if args.evaluate_during_training:
            val_dataloader = _make_vqa_loader(args, args.val_yaml, tokenizer, is_train=False)
        last_checkpoint = train(args, train_dataloader, val_dataloader, model, tokenizer)

        # test the last checkpoint after training
        if args.do_test:
            logger.info("Evaluate on dataset: " + args.test_yaml)
            test_dataloader = _make_vqa_loader(args, args.test_yaml, tokenizer, is_train=False)
            # Predictions belong to the checkpoint that was just trained, not
            # to the one the run started from; fall back to the initial
            # checkpoint only if training never saved one.
            predict_file = get_predict_file(last_checkpoint or checkpoint,
                                            test_dataloader.dataset.yaml_file, args)
            test(args, test_dataloader, model, predict_file)

    # inference and evaluation
    elif args.do_test or args.do_eval:

        if args.do_eval:
            # do_eval takes precedence when both flags are set.
            logger.info("Evaluate on dataset: " + args.val_yaml)
            test_dataloader = _make_vqa_loader(args, args.val_yaml, tokenizer, is_train=False)

            eval_score = evaluate(args, test_dataloader, model)
            logger.info("Evaluation score: {}".format(eval_score))
        else:
            logger.info("Evaluate on dataset: " + args.test_yaml)
            test_dataloader = _make_vqa_loader(args, args.test_yaml, tokenizer, is_train=False)

            predict_file = get_predict_file(checkpoint, test_dataloader.dataset.yaml_file, args)
            test(args, test_dataloader, model, predict_file)
            logger.info("Prediction results saved to: {}".format(predict_file))


# Run the train / test / eval pipeline only when executed as a script,
# so importing this module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()
