import argparse
import os
import random
import math
import logging
import sys

import torch
from fairseq import (
    tasks,
    utils,
)
from fairseq.models.roberta import RobertaModel
from fairseq.trainer import Trainer
from fairseq.logging import progress_bar


# Short sentence that is encoded and pushed through the sparsified model as a
# smoke test before the checkpoint is written.
SAMPLE_TEXT = "Hello world! cécé herlolip"

# All log output goes to stdout. The level defaults to INFO and can be
# overridden through the LOGLEVEL environment variable (e.g. LOGLEVEL=debug).
_log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    stream=sys.stdout,
    level=_log_level,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("induce_sparse_model_from_dense")


def _mask_grad(encoder, name):
    """Return ``encoder.<name>.grad``, raising a clear error if backward left it unset."""
    grad = getattr(encoder, name).grad
    if grad is None:
        raise RuntimeError(
            "{} received no gradient; cannot estimate its importance".format(name)
        )
    return grad


def _non_finite_message(grads):
    """Return a warning for the first NaN/Inf tensor in ``grads`` (name -> tensor), else None.

    Tensors are visited in the dict's iteration order; each is checked for NaN
    before Inf.
    """
    for name, grad in grads.items():
        if torch.isnan(grad).any():
            return "nan detected on {}.grad".format(name)
        if torch.isinf(grad).any():
            return "inf detected on {}.grad".format(name)
    return None


def compute_importance(args, model):
    """Accumulate the absolute gradient of every sparsity mask over the valid set.

    Follows the importance estimate of https://arxiv.org/abs/1905.10650: for each
    validation sample, run forward/backward through ``task.train_step`` and add
    ``|d loss / d mask|`` to a running total per mask. No update step is taken
    in this function.

    Args:
        args: fairseq argument namespace. Read here: ``embed_factorize``,
            ``valid_subset`` (comma-separated), ``log_format``, ``log_interval``,
            ``no_progress_bar`` and optionally ``tpu``; it is also passed to
            ``tasks.setup_task``, ``build_criterion`` and ``Trainer``.
        model: model whose ``encoder`` has ``head_masks`` and ``hidden_masks``
            (plus ``rank_mask`` when ``args.embed_factorize`` is true).

    Returns:
        ``(rank_importance, head_importance, hidden_importance)``, each a tensor
        with the size of the corresponding mask. ``rank_importance`` is ``None``
        when ``args.embed_factorize`` is false.

    Raises:
        RuntimeError: if a tracked mask has no gradient after backward.
    """
    encoder = model.encoder
    # Accumulators keyed by the encoder attribute holding the mask they score.
    # Insertion order (rank, head, hidden) is also the NaN/Inf check order.
    # The rank mask is only tracked when embeddings are factorized, so it is
    # never touched otherwise.
    importance = {}
    if args.embed_factorize:
        importance["rank_mask"] = torch.zeros(encoder.rank_mask.size(), device=encoder.rank_mask.device)
    importance["head_masks"] = torch.zeros(encoder.head_masks.size(), device=encoder.head_masks.device)
    importance["hidden_masks"] = torch.zeros(encoder.hidden_masks.size(), device=encoder.hidden_masks.device)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset
    valid_subsets = args.valid_subset.split(",")
    for valid_sub_split in valid_subsets:
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build criterion and trainer
    criterion = task.build_criterion(args)
    trainer = Trainer(args, task, model, criterion)

    # Nothing in this function advances the trainer's update counter, so
    # get_num_updates() cannot tell us whether the first step already happened.
    cache_emptied = False

    trainer.begin_valid_epoch(1)
    for subset in valid_subsets:
        logger.info('begin gradient accumulation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"accumulate on '{subset}' subset",
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        for sample in progress:
            trainer.zero_grad()
            # Also zero the mask gradients directly, so each step sees only this
            # sample's gradient even if the masks are not optimizer parameters.
            for name in importance:
                stale_grad = getattr(encoder, name).grad
                if stale_grad is not None:
                    stale_grad.zero_()

            sample = trainer._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = trainer._prepare_sample(trainer._dummy_batch)
                is_dummy_batch = True
            else:
                if trainer._dummy_batch == "DUMMY":
                    trainer._dummy_batch = sample
                is_dummy_batch = False

            # forward and backward
            trainer.task.train_step(
                sample=sample,
                model=trainer.model,
                criterion=trainer.criterion,
                optimizer=trainer.optimizer,
                update_num=trainer.get_num_updates(),
                ignore_grad=is_dummy_batch,
            )

            # Skip samples whose mask gradients are not finite.
            grads = {name: _mask_grad(encoder, name) for name in importance}
            warning = _non_finite_message(grads)
            if warning is not None:
                logger.warning(warning)
                continue

            # accumulate importance, following https://arxiv.org/abs/1905.10650
            for name, grad in grads.items():
                importance[name] += grad.abs().detach()

            # emptying the CUDA cache once after the first step can
            # reduce the chance of OOM
            if trainer.cuda and not cache_emptied:
                torch.cuda.empty_cache()
                cache_emptied = True

    return importance.get("rank_mask"), importance["head_masks"], importance["hidden_masks"]


def _parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with ``input``, ``checkpoint``, ``data``, ``output``,
        ``sparsity`` and ``lang_agnostic``.

    Exits through ``parser.error`` when ``--sparsity`` is outside (0, 1].
    """
    parser = argparse.ArgumentParser(
        description="Tool to convert standard XLM-R model to a sparse model",
    )
    # fmt: off
    parser.add_argument('--input', '-i', type=str, required=True, metavar='DIR',
                        help='Input checkpoint directory path.')
    parser.add_argument('--checkpoint', '-c', type=str, required=True, metavar='FILE',
                        help='Input checkpoint file name.')
    parser.add_argument('--data', '-d', type=str, required=True, metavar='DIR',
                        help='Data directory path.')
    parser.add_argument('--output', '-o', type=str, required=True, metavar='FILE',
                        help='Output checkpoint file name.')
    parser.add_argument("--sparsity", "-s", type=float, required=True, metavar='D',
                        help="The portion of parameters that you want to retain.")
    parser.add_argument(
        "--lang-agnostic",
        action="store_true",
        default=False,
        help="if set, then the sparsity mask will be shared among languages",
    )
    # fmt: on
    args = parser.parse_args()
    # --sparsity is the fraction to keep. It feeds torch.quantile(..., 1 - sparsity),
    # which needs a value in [0, 1], and keeping 0% would prune everything.
    if not 0.0 < args.sparsity <= 1.0:
        parser.error("--sparsity must be in (0, 1], got {}".format(args.sparsity))
    return args


def _patch_checkpoint_args(ckpt_args, lang_agnostic):
    """Modify a checkpoint's argument namespace in place so it reloads as a sparse model.

    Args:
        ckpt_args: the ``args`` namespace stored in the checkpoint (mutated).
        lang_agnostic: value stored as ``ckpt_args.lang_agnostic``.
    """
    # add missing arguments (default values)
    ckpt_args.pipeline_model_parallel = False
    ckpt_args.freq_weighted_replacement = False
    ckpt_args.distributed_wrapper = 'DDP'
    ckpt_args.broadcast_buffers = False
    ckpt_args.data_buffer_size = 10
    ckpt_args.bf16 = False
    ckpt_args.model_parallel_size = 1
    ckpt_args.zero_sharding = 'none'

    # add sparse related arguments
    ckpt_args.arch = "sparse_xlmr_base"
    ckpt_args.criterion = "sparse_masked_lm"
    ckpt_args.sparsity_weight = 0.
    ckpt_args.one4all = None
    ckpt_args.diagonal_weight = 0.
    ckpt_args.lang2group = None
    ckpt_args.non_parameterize = True
    ckpt_args.lang_agnostic = lang_agnostic
    ckpt_args.embed_factorize = True
    ckpt_args.clamp = False
    ckpt_args.monolingual_langs = "af,am,ar,as,az,be,bg,bn,bn_rom,br,bs,ca,cs,cy,da,de,el,en,eo,es,et,eu,fa,fi,fr,fy,ga,gd,gl,gu,ha,he,hi,hi_rom,hr,hu,hy,id,is,it,ja,jv,ka,kk,km,kn,ko,ku,ky,la,lo,lt,lv,mg,mk,ml,mn,mr,ms,my_zaw,my,ne,nl,no,om,or,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,so,sq,sr,su,sv,sw,ta,ta_rom,te,te_rom,th,tl,tr,ug,uk,ur,ur_rom,uz,vi,xh,yi,zh-Hans,zh-Hant"

    # update arguments
    ckpt_args.log_format = "tqdm"
    # must be 1 to accumulate absolute gradients correctly: the sum of per-sample
    # |grad| differs from |grad| of a whole batch
    ckpt_args.batch_size_valid = 1
    ckpt_args.distributed_world_size = 1  # overwrite this to prevent launching a distributed model
    ckpt_args.fp16 = False  # overwrite this to prevent launching a fp16 model (avoid NaN or Inf)
    ckpt_args.valid_subset = "valid"  # overwrite this to prevent loading non-existing data


def _threshold_importance(rank_importance, head_importance, hidden_importance, sparsity):
    """Turn importance scores into 0/1 masks, with an independent threshold per row.

    For each row, every score is repeated by its unit's weight, and the
    ``1 - sparsity`` quantile of the repeated scores is used as the threshold.
    Units scoring strictly above it are kept.

    Args:
        rank_importance: 2-D scores for the rank mask, or ``None``.
        head_importance: 2-D scores for the attention-head masks.
        hidden_importance: 2-D scores for the FFN hidden-unit masks.
        sparsity: weighted fraction of units to retain, in (0, 1].

    Returns:
        ``(rank_mask, head_mask, hidden_mask, retained)``. The masks are float
        0/1 column slices of one ``(rows, units)`` tensor, so they are
        non-contiguous. ``rank_mask`` is ``None`` when ``rank_importance`` is.
        ``retained`` is the weighted fraction kept per row.
    """
    # Relative weight of one unit in each group. NOTE(review): presumably
    # 64 * 4 = head_dim x (q, k, v, out) projections and 2 = FFN in/out — confirm.
    head_unit_weight = 64 * 4
    hidden_unit_weight = 2

    scores = [head_importance, hidden_importance]
    weighted_scores = [
        head_importance.repeat(1, head_unit_weight),
        hidden_importance.repeat(1, hidden_unit_weight),
    ]
    weights = [
        torch.ones_like(head_importance) * head_unit_weight,
        torch.ones_like(hidden_importance) * hidden_unit_weight,
    ]
    if rank_importance is not None:
        scores.insert(0, rank_importance)
        weighted_scores.insert(0, rank_importance)
        weights.insert(0, torch.ones_like(rank_importance))

    flat = torch.cat(scores, dim=-1)
    threshold = torch.quantile(torch.cat(weighted_scores, dim=-1), 1 - sparsity, dim=-1, keepdim=True)
    mask = (flat > threshold).float()
    weight = torch.cat(weights, dim=-1)
    retained = (mask * weight).sum(dim=-1) / weight.sum(dim=-1)

    # Split the concatenated mask back into its groups: [rank | head | hidden].
    offset = 0
    rank_mask = None
    if rank_importance is not None:
        offset = rank_importance.size(1)
        rank_mask = mask[:, :offset]
    head_end = offset + head_importance.size(1)
    return rank_mask, mask[:, offset:head_end], mask[:, head_end:], retained


def _kept_fraction(mask):
    """Fraction of non-zero entries in ``mask``."""
    return torch.count_nonzero(mask).float().item() / float(torch.numel(mask))


def _log_mask_summary(model_args, langs, retained, rank_mask, head_masks, hidden_masks):
    """Log per-row retained ratios, plus overall stats and one sample row per mask."""
    logger.info("Per language sparsity: {}".format(dict(zip(langs, retained.cpu().tolist()))))

    head_masks = head_masks.view(-1, model_args.encoder_layers, model_args.encoder_attention_heads)
    hidden_masks = hidden_masks.view(-1, model_args.encoder_layers, model_args.encoder_ffn_embed_dim)

    # The masks can have fewer rows than there are languages (presumably one
    # shared row with --lang-agnostic — confirm), so only sample existing rows.
    num_rows = head_masks.size(0)
    idx = random.randrange(min(num_rows, len(langs)))
    label = langs[idx] if num_rows == len(langs) else "row {}".format(idx)

    if rank_mask is not None:
        logger.info("Rank mask    sparsity: overall {:.3f}, samples [{}]: {}".format(
            _kept_fraction(rank_mask), label, rank_mask[idx, :].tolist())
        )
    logger.info("Head masks   sparsity: overall {:.3f}, samples [{}]: {}".format(
        _kept_fraction(head_masks), label, head_masks[idx, 0, :].tolist())
    )
    logger.info("Hidden masks sparsity: overall {:.3f}, samples [{}]: {}".format(
        _kept_fraction(hidden_masks), label, hidden_masks[idx, 0, :].tolist())
    )


def _sparsify_and_save(args, temp_name):
    """Reload the patched checkpoint, compute and apply masks, smoke-test, and save.

    Args:
        args: parsed command-line arguments.
        temp_name: file name (inside ``args.input``) of the patched checkpoint.
    """
    # reload checkpoint
    roberta = RobertaModel.from_pretrained(
        args.input,
        checkpoint_file=temp_name,
        data_name_or_path=args.data,
    )
    # TODO: eval() has no effect here on dropout (presumably because
    # task.train_step switches the model back to train mode — confirm).
    roberta.eval()
    roberta = roberta.cuda() if torch.cuda.is_available() else roberta
    print(roberta.model)
    model_args = roberta.args

    # compute masks for all languages
    rank_importance, head_importance, hidden_importance = compute_importance(model_args, roberta.model)
    torch.save([rank_importance, head_importance, hidden_importance], "importance.pt")

    rank_flat, head_flat, hidden_flat, retained = _threshold_importance(
        rank_importance, head_importance, hidden_importance, args.sparsity
    )

    # The slices are non-contiguous columns of one tensor. Copy them to
    # contiguous memory so .view() is always valid, and so each mask parameter
    # owns its own storage instead of aliasing (and, when saved, serializing)
    # the whole concatenated mask.
    encoder = roberta.model.encoder
    rank_mask = None
    if rank_flat is not None:
        rank_mask = rank_flat.contiguous().view(encoder.rank_mask.size())
        encoder.rank_mask.data = rank_mask.type_as(encoder.rank_mask)
    head_masks = head_flat.contiguous().view(encoder.head_masks.size())
    hidden_masks = hidden_flat.contiguous().view(encoder.hidden_masks.size())
    encoder.head_masks.data = head_masks.type_as(encoder.head_masks)
    encoder.hidden_masks.data = hidden_masks.type_as(encoder.hidden_masks)

    # nothing but printing
    langs = [lang.strip() for lang in model_args.monolingual_langs.split(",")]
    _log_mask_summary(model_args, langs, retained, rank_mask, head_masks, hidden_masks)

    # dummy run to see whether any error will be thrown
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
    input_ids = input_ids.cuda() if torch.cuda.is_available() else input_ids
    lang_id = torch.tensor([langs.index("en")], device=input_ids.device)
    with torch.no_grad():
        roberta.model(input_ids, lang_id=lang_id)

    # save model by replacing state_dict() in the patched checkpoint
    ckpt = torch.load(os.path.join(args.input, temp_name), map_location="cpu")
    ckpt["model"] = roberta.model.state_dict()
    torch.save(ckpt, args.output)
    logger.info("Save checkpoint to {}".format(args.output))


def main():
    """Convert a dense XLM-R checkpoint into a sparse one.

    Example::

        python scripts/induce_sparse_model_from_dense.py -i . -c model.pt -d slice.0 -o model.pt.sparse0.5 -s 0.5

    The input checkpoint's args are patched for the sparse architecture and
    written to ``<input>/temp.pt`` so that ``RobertaModel.from_pretrained`` can
    reload it. That temporary file is removed before returning, also on error.

    Raises:
        FileExistsError: if ``<input>/temp.pt`` already exists.
    """
    args = _parse_args()
    print(args)

    # load pretrained model
    ckpt_path = os.path.join(args.input, args.checkpoint)
    logger.info("Update {}".format(ckpt_path))
    ckpt = torch.load(ckpt_path, map_location="cpu")
    _patch_checkpoint_args(ckpt["args"], args.lang_agnostic)

    # save updated checkpoint for reload, refusing to clobber an existing file
    temp_name = "temp.pt"
    temp_path = os.path.join(args.input, temp_name)
    if os.path.exists(temp_path):
        raise FileExistsError("{} already exists, please check".format(temp_path))
    torch.save(ckpt, temp_path)
    del ckpt

    try:
        _sparsify_and_save(args, temp_name)
    finally:
        # delete temp file even on failure, so a rerun is not blocked by the check above
        os.remove(temp_path)


if __name__ == "__main__":
    main()
