import copy
import sys
from os import path

import wandb
import yaml

# Append the directory two levels above this file's directory to sys.path
# (dirname of the absolute parent of dirname(__file__)), presumably the
# repository root, so the project-local `utils`, `modeling`, `args` and `task`
# imports below resolve when this file is run as a script — confirm layout.
sys.path.append(
    path.dirname(path.abspath(path.dirname(path.abspath(path.dirname(__file__)))))
)

import os.path as op
import time, json
import torch
from tqdm import tqdm

from utils.logger import setup_loguru_logger
from utils.tsv_file_ops import (tsv_writer, concat_tsv_files,
                                delete_tsv_files, reorder_tsv_keys)
from utils.misc import (set_seed)
from utils.caption_evaluate import (evaluate_on_coco_caption,
                                    evaluate_on_nocaps, ScstRewardCriterion)
from utils.distributed_processing import (
    get_world_size, get_rank, is_main_process, synchronize, ensure_init_process_group
)

from utils.wandb_setup import wandb_setup
from utils.param_utils import compute_param_norm

from modeling.modeling_bert import BertForImageCaptioning

from pytorch_transformers import BertTokenizer, BertConfig
from pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule

from args.args_captioning import get_arguments, restore_training_settings, save_arg
from task.captioning.dataloader import make_data_loader
from task.captioning.IO import save_checkpoint, get_predict_file, get_evaluate_file, save_best
from task.captioning.factory import model_factory

def compute_score_with_logits(logits, labels):
    """Return a boolean tensor marking where the arg-max prediction equals ``labels``.

    ``logits`` is reduced over its last dimension; the resulting indices are
    detached from autograd and compared element-wise with ``labels``.
    """
    predicted_ids = logits.max(dim=-1).indices.detach()
    return predicted_ids == labels


def _normalize_val_metrics(res_file, data_name):
    """Return validation metrics keyed the way the COCO evaluator names them.

    For nocaps data (``'nocaps' in data_name``) each metric in ``res_file`` is a
    nested mapping; only its ``"entire"`` entry is kept, and the nocaps metric
    names (``B4``, ``ROUGE-L``) are renamed to ``Bleu_4`` / ``ROUGE_L`` so the
    caller can read both datasets with the same keys. Any other dataset's
    results are returned unchanged (the same dict object).
    """
    if 'nocaps' not in data_name:
        return res_file
    entire = {key: value["entire"] for key, value in res_file.items()}
    return {
        "CIDEr": entire['CIDEr'],
        "Bleu_4": entire['B4'],
        "METEOR": entire['METEOR'],
        "ROUGE_L": entire['ROUGE-L'],
        "SPICE": entire['SPICE'],
    }


def train(args, train_dataloader, val_dataloader, test_dataloader, model, tokenizer) -> str:
    """Fine-tune ``model`` for image captioning and return the last checkpoint directory.

    Training uses masked-token cross-entropy, or self-critical sequence
    training via ``scst_train_iter`` when ``args.scst`` is set. A checkpoint is
    saved every ``args.save_steps`` optimizer steps and when ``t_total`` steps
    are reached; training stops after ``t_total`` optimizer steps. With
    ``args.evaluate_during_training`` each saved checkpoint is evaluated on
    ``val_dataloader`` (and ``test_dataloader`` if not None); metrics go to
    wandb and to ``<args.output_dir>/eval_logs.json``. Finally the checkpoint
    with the best validation CIDEr is passed to ``save_best`` on the main process.

    Args:
        args: parsed run configuration (mutated: ``num_train_epochs`` is
            recomputed when ``args.max_steps > 0``).
        train_dataloader: yields ``(img_keys, batch)`` tuples.
        val_dataloader: used only when ``args.evaluate_during_training``.
        test_dataloader: optional; evaluated alongside validation when given.
        model: captioning model (wrapped in DDP when ``args.distributed``).
        tokenizer: tokenizer passed to checkpointing and decoding.

    Returns:
        Directory of the last checkpoint saved, or ``""`` if none was saved.

    Raises:
        ValueError: if ``args.scheduler`` is not ``constant``, ``linear`` or
            ``plateau_linear``.
    """
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    iters_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        # Round up so the epochs cover max_steps; the loop stops at t_total.
        args.num_train_epochs = args.max_steps // iters_per_epoch + 1
    else:
        t_total = iters_per_epoch * args.num_train_epochs

    # Prepare optimizer and scheduler: biases and LayerNorm weights get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    warmup_steps = args.warmup_step_ratio * args.num_train_epochs * iters_per_epoch
    plateau_steps = args.plateau_step_ratio * args.num_train_epochs * iters_per_epoch

    if args.scheduler == "constant":
        scheduler = WarmupConstantSchedule(
            optimizer,
            warmup_steps=warmup_steps
        )
        logger.info(f"Constant Scheduler ({t_total}) with Warm Up ({warmup_steps}) Steps")

    elif args.scheduler == "linear":
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total
        )
        logger.info(f"Warm Up Linear Scheduler ({t_total}) with Warm Up ({warmup_steps}) Steps")

    elif args.scheduler == "plateau_linear":
        # NOTE(review): upstream pytorch_transformers' WarmupLinearSchedule has no
        # `plateau_steps` argument; presumably a patched version is installed — confirm.
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            plateau_steps=plateau_steps,
            t_total=t_total
        )
        logger.info(f"Warm Up Linear Scheduler({t_total})" + \
                    f" with Warm up ({warmup_steps}) and Plateau ({plateau_steps}) Steps")
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    logger.info("***** Running training *****")
    logger.info("  Num Epochs = " + f"{args.num_train_epochs}")
    logger.info("  Batch size per GPU = " + f"{args.per_gpu_train_batch_size}")
    logger.info("  Total train batch size (w. parallel, & accumulation) = " + \
                f"{args.per_gpu_train_batch_size * get_world_size() * args.gradient_accumulation_steps}")
    logger.info("  Gradient Accumulation steps = "+ f"{args.gradient_accumulation_steps}")
    logger.info("  Total optimization steps = " + f"{t_total}")

    if args.scst:
        scst_criterion = ScstRewardCriterion(
            cider_cached_tokens=op.join(args.data_dir, args.cider_cached_tokens),
            baseline_type=args.sc_baseline_type,
        )
        logger.info("  SCST training...")

    global_step, global_loss, global_acc = 0, 0.0, 0.0
    model.zero_grad()

    eval_log = []
    checkpoint_dir = ""

    best_score = 0
    best_ckpt_dir = ""

    for epoch in range(int(args.num_train_epochs)):
        for step, (img_keys, batch) in enumerate(train_dataloader):
            batch = tuple(t.to(args.device) for t in batch)

            if not args.scst:
                model.train()
                inputs = {'input_ids': batch[0], 'attention_mask': batch[1],
                          'token_type_ids': batch[2], 'img_feats': batch[3],
                          'masked_pos': batch[4], 'masked_ids': batch[5]
                          }

                outputs = model(**inputs)
                loss, logits = outputs[:2]

                # Padding (id 0) entries of masked_ids are not predicted; drop them.
                masked_ids = inputs['masked_ids']
                masked_ids = masked_ids[masked_ids != 0]

                batch_score = compute_score_with_logits(logits, masked_ids)
                batch_acc = torch.sum(batch_score.float()) / torch.sum(inputs['masked_pos'])
            else:
                loss = scst_train_iter(args, train_dataloader, model, scst_criterion, img_keys, batch, tokenizer)
                batch_acc = scst_criterion.get_score()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                batch_acc = batch_acc / args.gradient_accumulation_steps

            loss.backward()

            # Plain Python numbers for accumulation/logging (avoids logging tensors).
            step_loss = loss.item()
            step_acc = float(batch_acc)
            global_loss += step_loss
            global_acc += step_acc

            if (step + 1) % args.gradient_accumulation_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                param_norm = compute_param_norm(model.parameters())

                global_step += 1
                optimizer.step()
                scheduler.step()
                model.zero_grad()

                if global_step % args.logging_steps == 0:
                    logger.info("Epoch: {}, global_step: {}, lr: {:.6f}, loss: {:.4f} ({:.4f}), " \
                                "score: {:.4f} ({:.4f})".format(epoch,
                                                                global_step,
                                                                optimizer.param_groups[0]["lr"],
                                                                step_loss * args.gradient_accumulation_steps,
                                                                global_loss / global_step,
                                                                step_acc * args.gradient_accumulation_steps,
                                                                global_acc / global_step)
                                )
                    if is_main_process():
                        wandb.log({
                            "Epoch": epoch,
                            "step": step,
                            "global_step": global_step,
                            "lr": optimizer.param_groups[0]["lr"],
                            "loss": step_loss * args.gradient_accumulation_steps,
                            "global_loss": global_loss / global_step,
                            "accuracy": step_acc * args.gradient_accumulation_steps,
                            "global_accuracy": global_acc / global_step,
                            "grad_norm": grad_norm,
                            "param_norm": param_norm
                        })

                if (args.save_steps > 0 and global_step % args.save_steps == 0) or \
                        global_step == t_total:
                    checkpoint_dir = save_checkpoint(logger, model, tokenizer, args, epoch, global_step)

                    # evaluation
                    if args.evaluate_during_training:
                        logger.info(f"Perform evaluation at step: {global_step}")
                        val_evaluate_file = evaluate(args,
                                                     val_dataloader,
                                                     model,
                                                     tokenizer,
                                                     checkpoint_dir)

                        val_data_name = val_dataloader.dataset.yaml_file.split('/')[-2]

                        with open(val_evaluate_file, 'r') as f:
                            val_res = _normalize_val_metrics(json.load(f), val_data_name)

                        # Wandb log for the eval phase.
                        if is_main_process():
                            wandb.log({
                                "Epoch": epoch,
                                "step": step,
                                "global_step": global_step,
                                "CIDEr": val_res['CIDEr'],
                                "BLEU4": val_res['Bleu_4'],
                                "METEOR": val_res['METEOR'],
                                "ROUGE": val_res['ROUGE_L'],
                                "SPICE": val_res['SPICE'],
                            })

                        if val_res['CIDEr'] > best_score:
                            best_score = val_res['CIDEr']
                            best_ckpt_dir = checkpoint_dir

                        val_res['epoch'] = epoch
                        val_res['global_step'] = global_step
                        val_res['best_CIDEr'] = best_score

                        eval_log.append(val_res)

                        # Save the eval_logs
                        if is_main_process():
                            with open(args.output_dir + '/eval_logs.json', 'w') as f:
                                json.dump(eval_log, f)

                        # test split evaluate
                        if test_dataloader is not None:
                            test_evaluate_file = evaluate(args,
                                                          test_dataloader,
                                                          model,
                                                          tokenizer,
                                                          checkpoint_dir)

                            with open(test_evaluate_file, 'r') as f:
                                test_res_file = json.load(f)

                            if is_main_process():
                                wandb.log({
                                    "Epoch": epoch,
                                    "step": step,
                                    "global_step": global_step,
                                    "test_CIDEr": test_res_file['CIDEr'],
                                    "test_BLEU4": test_res_file['Bleu_4'],
                                    "test_METEOR": test_res_file['METEOR'],
                                    "test_ROUGE": test_res_file['ROUGE_L'],
                                    "test_SPICE": test_res_file['SPICE'],
                                })

                        synchronize()

                # Stop once the scheduled number of optimizer steps is done
                # (num_train_epochs is rounded up when max_steps is used).
                if global_step >= t_total:
                    break

        if global_step >= t_total:
            break

    # Save best model to best folder.
    # NOTE(review): best_ckpt_dir is still "" if no evaluation ran; save_best must tolerate that.
    if is_main_process():
        save_best(logger, args.output_dir, best_ckpt_dir)

    synchronize()
    return checkpoint_dir


def scst_train_iter(args, train_dataloader, model, scst_criterion,
                    img_keys, batch, tokenizer):
    """Run one self-critical sequence training (SCST) step and return its loss.

    If ``args.sc_baseline_type == 'greedy'``, a baseline caption is first decoded
    with the model in eval mode and without gradients; otherwise the baseline
    is ``None``. The model is then switched to train mode and
    ``args.sc_train_sample_n`` captions per image are sampled with gradients
    kept on their log-probabilities. Ground-truth captions for ``img_keys`` come
    from ``train_dataloader.dataset.get_captions_by_key``. The value returned
    is whatever ``scst_criterion(gt, baseline, samples, sample_logprobs)`` returns.
    """
    special_tokens = [tokenizer.cls_token, tokenizer.sep_token,
                      tokenizer.pad_token, tokenizer.mask_token]
    cls_id, sep_id, pad_id, mask_id = tokenizer.convert_tokens_to_ids(special_tokens)

    decode_kwargs = dict(
        is_decode=True,
        input_ids=batch[0],
        attention_mask=batch[1],
        token_type_ids=batch[2],
        img_feats=batch[3],
        masked_pos=batch[4],
        do_sample=False,
        bos_token_id=cls_id,
        pad_token_id=pad_id,
        eos_token_ids=[sep_id],
        mask_token_id=mask_id,
        # object-detection label inputs
        add_od_labels=args.add_od_labels,
        od_labels_start_posid=args.max_seq_a_length,
        # generation / beam-search hyper-parameters
        max_length=args.max_gen_length,
        num_beams=args.sc_beam_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        length_penalty=args.length_penalty,
        num_return_sequences=1,
        num_keep_best=1,
    )

    def decode_all(id_rows):
        # Turn each row of token ids into a caption string.
        return [tokenizer.decode(row.tolist(), skip_special_tokens=True) for row in id_rows]

    greedy_res = None
    if args.sc_baseline_type == 'greedy':
        model.eval()
        with torch.no_grad():
            greedy_ids, _ = model(**decode_kwargs)
            greedy_ids.squeeze_(1)  # drop the single-hypothesis axis in place
        greedy_res = decode_all(greedy_ids)

    model.train()
    sample_kwargs = dict(decode_kwargs, do_sample=True,
                         num_return_sequences=args.sc_train_sample_n)
    sampled_ids, sampled_logprobs = model(**sample_kwargs)
    sampled_ids = sampled_ids.squeeze(1)
    sampled_logprobs = sampled_logprobs.squeeze(1)

    # The policy gradient must flow through the log-probs, never the token ids.
    assert sampled_logprobs.requires_grad
    assert not sampled_ids.requires_grad
    sample_res = decode_all(sampled_ids)

    dataset = train_dataloader.dataset
    gt_res = [dataset.get_captions_by_key(key) for key in img_keys]
    return scst_criterion(gt_res, greedy_res, sample_res, sampled_logprobs)




def evaluate(args, val_dataloader, model, tokenizer, model_dir):
    """Generate predictions for ``val_dataloader`` and score them; return the score-file path.

    Predictions are written by ``test`` to the file named by
    ``get_predict_file``. The main process then scores them: with
    ``evaluate_on_nocaps('val', ...)`` when the dataset directory name (the
    second-to-last component of the dataset yaml path) contains ``'nocaps'``,
    otherwise with ``evaluate_on_coco_caption`` against the dataset's
    COCO-format caption file. ``synchronize()`` is called before and after
    scoring, so every rank returns only after the score file has been written.
    """
    yaml_path = val_dataloader.dataset.yaml_file
    predict_file = get_predict_file(model_dir, yaml_path, args)

    test(args, val_dataloader, model, tokenizer, predict_file)
    synchronize()

    evaluate_file = get_evaluate_file(predict_file, args.num_tags)
    if is_main_process():
        data = yaml_path.split('/')[-2]
        if 'nocaps' in data:
            result = evaluate_on_nocaps('val', predict_file, data_dir=args.data_dir,
                                        evaluate_file=evaluate_file)
        else:
            caption_file = val_dataloader.dataset.get_caption_file_in_coco_format()
            result = evaluate_on_coco_caption(predict_file, caption_file, outfile=evaluate_file)
        logger.info(f'evaluation result: \n{yaml.dump(result)}')
        logger.info('evaluation result saved to {}'.format(evaluate_file))

    synchronize()
    return evaluate_file


def test(args, test_dataloader, model, tokenizer, predict_file):
    """Decode captions for every image in ``test_dataloader`` and write them to ``predict_file``.

    Each row written is ``(img_key, json_text)``. ``json_text`` is a JSON array of
    ``{"caption": str, "conf": float}`` objects, one per returned hypothesis,
    where ``conf`` is ``exp`` of the model's second output. With more than one
    process, each rank writes its own shard ``<stem>_<rank>_<world_size><ext>``.
    The main process then concatenates the shards into ``predict_file``,
    deletes them, and reorders rows to ``test_dataloader.dataset.image_keys``.
    ``synchronize()`` is called per batch and after writing, so every rank
    should call this function.
    """
    cls_token_id, sep_token_id, pad_token_id, mask_token_id = \
        tokenizer.convert_tokens_to_ids([tokenizer.cls_token, tokenizer.sep_token,
                                         tokenizer.pad_token, tokenizer.mask_token])
    world_size = get_world_size()
    stem, ext = op.splitext(predict_file)

    def shard_file(rank):
        # Per-rank output path used when running with more than one process.
        return stem + '_{}_{}'.format(rank, world_size) + ext

    cache_file = predict_file if world_size == 1 else shard_file(get_rank())

    model.eval()
    inputs_param = {'is_decode': True,
                    'do_sample': False,
                    'bos_token_id': cls_token_id,
                    'pad_token_id': pad_token_id,
                    'eos_token_ids': [sep_token_id],
                    'mask_token_id': mask_token_id,
                    # for adding od labels
                    'add_od_labels': args.add_od_labels,
                    'od_labels_start_posid': args.max_seq_a_length,
                    # hyperparameters of beam search
                    'max_length': args.max_gen_length,
                    'num_beams': args.num_beams,
                    "temperature": args.temperature,
                    "top_k": args.top_k,
                    "top_p": args.top_p,
                    "repetition_penalty": args.repetition_penalty,
                    "length_penalty": args.length_penalty,
                    "num_return_sequences": args.num_return_sequences,
                    "num_keep_best": args.num_keep_best,
                    }
    if args.use_cbs:
        inputs_param.update({'use_cbs': True,
                             'min_constraints_to_satisfy': args.min_constraints_to_satisfy,
                             })

    def gen_rows():
        # Lazily yield (img_key, json_text) rows so tsv_writer can stream them.
        time_meter = 0.0
        num_batches = 0

        with torch.no_grad():
            for img_keys, batch in tqdm(test_dataloader):
                num_batches += 1
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'input_ids': batch[0], 'attention_mask': batch[1],
                    'token_type_ids': batch[2], 'img_feats': batch[3],
                    'masked_pos': batch[4],
                }

                # Diff num tag: zero attention-mask columns [20 + num_tags, 50).
                # NOTE(review): 20 and 50 are hard-coded; presumably the caption
                # length and end of the tag span — confirm against the sequence-length args.
                if args.num_tags is not None:
                    start_idx = 20 + args.num_tags
                    inputs["attention_mask"][..., start_idx:50].fill_(0)
                if args.use_cbs:
                    inputs.update({
                        'fsm': batch[5],
                        'num_constraints': batch[6],
                    })
                inputs.update(inputs_param)
                tic = time.time()
                # captions, logprobs
                # output[0]: (Batch, N-Best(1), Seq) generated caption
                # output[1]: (Batch, 1) log probability score
                # output[2]: (Batch, Layer, Seq, Seq) Attention Maps
                synchronize()
                outputs = model(**inputs)
                time_meter += time.time() - tic
                all_caps = outputs[0]  # batch_size * num_keep_best * max_len
                all_confs = torch.exp(outputs[1])
                for img_key, caps, confs in zip(img_keys, all_caps, all_confs):
                    res = []
                    for cap, conf in zip(caps, confs):
                        cap = tokenizer.decode(cap.tolist(), skip_special_tokens=True)
                        res.append({'caption': cap, 'conf': conf.item()})
                    if isinstance(img_key, torch.Tensor):
                        img_key = img_key.item()
                    yield img_key, json.dumps(res)

        # max(..., 1): an empty loader must not raise (division by zero / unbound index).
        logger.info("Inference model computing time: {} seconds per batch".format(
            time_meter / max(num_batches, 1)))

    tsv_writer(gen_rows(), cache_file)
    synchronize()
    if world_size > 1 and is_main_process():
        cache_files = [shard_file(i) for i in range(world_size)]
        concat_tsv_files(cache_files, predict_file)
        delete_tsv_files(cache_files)
        reorder_tsv_keys(predict_file, test_dataloader.dataset.image_keys, predict_file)
    synchronize()

def main():
    """Script entry point: configure the run, then train and/or evaluate.

    Initializes the process group, the module-level ``logger`` (read by
    ``train``/``evaluate``/``test``), seeding and — when ``args.do_train`` —
    wandb plus a saved copy of the config. It then does one of the following:
    - trains, optionally evaluating the last checkpoint on ``args.test_yaml``
      afterwards when ``args.do_test`` is set;
    - with ``args.do_test`` only, writes predictions;
    - with ``args.do_eval``, writes predictions and scores them.
    """
    # `logger` is a module-level global used by the other functions in this file.
    global logger

    # Get arguments
    args = get_arguments()

    # Setup CUDA, GPU & distributed training
    local_rank = ensure_init_process_group()
    args.local_rank = local_rank
    args.num_gpus = get_world_size()
    args.distributed = args.num_gpus > 1
    args.device = torch.device('cuda')
    synchronize()

    # Setup Logger
    logger = setup_loguru_logger("vlpretrain", args.output_dir, args.local_rank)
    logger.warning(f"Device: {args.device}, n_gpu: {args.num_gpus}")
    set_seed(args.seed, args.num_gpus)

    # Restore Argument
    args = restore_training_settings(args, logger)

    # WandB initialize
    if args.do_train:
        args.output_dir = wandb_setup(vars(args))

        if is_main_process():
            new_cfg_file = save_arg(args=vars(args), output_dir=args.output_dir)
            logger.info(f"new config saved to {new_cfg_file}")

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, BertForImageCaptioning, BertTokenizer
    model, tokenizer, checkpoint = model_factory(logger, args, config_class, model_class, tokenizer_class)
    model.to(args.device)

    # Logging the parameters (torch.device is replaced by a string for yaml.dump).
    if is_main_process():
        print_args = copy.deepcopy(vars(args))
        print_args["device"] = "cuda" if print_args["device"].type == "cuda" else "cpu"
        logger.info(f"Training/evaluation parameters : \n{yaml.dump(print_args)}")

    # Setup Data loader with arguments
    if args.do_train:
        train_dataloader = make_data_loader(args=args,
                                            yaml_file=args.train_yaml,
                                            tokenizer=tokenizer,
                                            logger=logger,
                                            is_distributed=args.distributed,
                                            is_train=True)
        val_dataloader = None
        test_dataloader = None
        if args.evaluate_during_training:
            val_dataloader = make_data_loader(args=args,
                                              yaml_file=args.val_yaml,
                                              tokenizer=tokenizer,
                                              logger=logger,
                                              is_distributed=args.distributed,
                                              is_train=False)
            # Only COCO / Flickr test splits are scored during training.
            # (`'coco' or 'flickr' in s` is always truthy, hence `any`.)
            if args.test_yaml and any(name in args.test_yaml for name in ('coco', 'flickr')):
                test_dataloader = make_data_loader(args=args,
                                                   yaml_file=args.test_yaml,
                                                   logger=logger,
                                                   tokenizer=tokenizer,
                                                   is_distributed=args.distributed,
                                                   is_train=False)

        last_checkpoint = train(args, train_dataloader, val_dataloader, test_dataloader, model, tokenizer)

        # test the last checkpoint after training
        if args.do_test:
            logger.info("Evaluate on dataset: " + args.test_yaml)
            test_dataloader = make_data_loader(args=args,
                                               yaml_file=args.test_yaml,
                                               logger=logger,
                                               tokenizer=tokenizer,
                                               is_distributed=args.distributed,
                                               is_train=False)
            evaluate(args, test_dataloader, model, tokenizer, last_checkpoint)

    # inference and evaluation
    elif args.do_test or args.do_eval:
        logger.info("Evaluate on dataset: " + args.test_yaml)
        test_dataloader = make_data_loader(args=args,
                                           yaml_file=args.test_yaml,
                                           logger=logger,
                                           tokenizer=tokenizer,
                                           is_distributed=args.distributed, is_train=False)

        if not args.do_eval:
            # Prediction only: write captions without scoring them.
            predict_file = get_predict_file(checkpoint, test_dataloader.dataset.yaml_file, args)
            test(args, test_dataloader, model, tokenizer, predict_file)
            logger.info("Prediction results saved to: {}".format(predict_file))
        else:
            evaluate_file = evaluate(args, test_dataloader, model, tokenizer, checkpoint)
            logger.info("Evaluation results saved to: {}".format(evaluate_file))


if __name__ == "__main__":
    # Run the training/evaluation pipeline only when executed as a script.
    main()
