import argparse
import os
import random
import pathlib

import torch
from fairseq.models.roberta import SparseXLMRModel
from fairseq.modules import TransformerSentenceEncoderLayer

from transformers.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.modeling_xlm_roberta import XLMRobertaForMaskedLM, XLMRobertaConfig


SAMPLE_TEXT = "Hello world! cécé herlolip"

# File name of the sentencepiece model looked up next to the checkpoint / output.
_SPM_FILENAME = 'sentencepiece.bpe.model'
# Last-resort sentencepiece model: the public xlm-roberta-base one.
_DEFAULT_SPM_URL = "https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model"


def _parse_args():
    """Parse the command line of the sparse XLM-R -> huggingface conversion tool."""
    parser = argparse.ArgumentParser(
        description="Tool to convert sparse XLM-R fairseq model to huggingface transformers model",
    )
    # fmt: off
    parser.add_argument('--input', '-i', type=str, required=True, metavar='DIR',
                        help='Input checkpoint directory path.')
    parser.add_argument('--checkpoint', '-c', type=str, required=True, metavar='FILE',
                        help='Input checkpoint file name.')
    parser.add_argument('--data', '-d', type=str, required=True, metavar='DIR',
                        help='Data directory path.')
    parser.add_argument('--output', '-o', type=str, required=True, metavar='DIR',
                        help='Output checkpoint directory path.')
    parser.add_argument('--sparsity', '-s', type=float, required=True, metavar='FLOAT',
                        help='The target sparsity of the model.')
    parser.add_argument('--spm', type=str, default=None, metavar='FILE',
                        help='Sentencepiece model path.')
    # fmt: on
    return parser.parse_args()


def _resolve_spm_path(spm, input_dir, output_dir):
    """Choose the sentencepiece model handed to fairseq.

    Priority order: the explicit ``spm`` path if it exists, then
    ``sentencepiece.bpe.model`` inside ``input_dir``, then inside ``output_dir``,
    and finally the public xlm-roberta-base model URL.

    Returns:
        A local file path or a URL string.
    """
    if spm is not None:
        if os.path.exists(spm):
            return spm
        # Don't silently ignore a path the user explicitly asked for.
        print(f"| WARNING: --spm {spm} does not exist, falling back to default locations")
    for directory in (input_dir, output_dir):
        candidate = os.path.join(directory, _SPM_FILENAME)
        if os.path.exists(candidate):
            return candidate
    return _DEFAULT_SPM_URL


def _report_mask_density(langs, rank_mask, head_masks, hidden_masks):
    """Print, for every non-None mask, its overall non-zero ratio and one sampled row.

    The sampled row is ``mask[idx, 0, :]`` for a randomly chosen language index
    ``idx``; masks are assumed to be indexed by language on their first axis.
    """
    idx = random.randrange(len(langs))
    # Note: count_nonzero / numel is the fraction of entries that are kept.
    print("| sparsity: count_nonzero / numel")
    named_masks = (
        ("Rank mask", rank_mask),
        ("Head masks", head_masks),
        ("Hidden masks", hidden_masks),
    )
    for label, mask in named_masks:
        if mask is None:
            continue
        ratio = torch.count_nonzero(mask).item() / mask.numel()
        print(f"| {label:<12} sparsity: overall {ratio:.3f}, samples [{langs[idx]}]: {mask[idx, 0, :].tolist()}")


def _copy_weights(model, roberta, config, rank_mask, head_masks, hidden_masks, is_sparse):
    """Transfer the fairseq weights (and sparse language masks) into ``model`` in place.

    Adapted from https://github.com/huggingface/transformers/blob/129fdae04033fe4adfe013b734deaec6ec34ae2e/src/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py

    Most huggingface parameters are rebound to the fairseq ``Parameter`` objects
    themselves (shared, not copied); a few are assigned through ``.data``.
    When ``is_sparse`` is true, masks that are ``None`` are skipped and the
    huggingface model keeps whatever it was initialised with for them. The output
    comparison run afterwards decides whether that is acceptable.
    """
    sent_encoder = roberta.model.encoder.sentence_encoder
    fairseq_lm_head = roberta.model.encoder.lm_head

    # Embeddings
    embeddings = model.roberta.embeddings
    embeddings.word_embeddings.weight = sent_encoder.embed_tokens.weight
    embeddings.position_embeddings.weight = sent_encoder.embed_positions.weight
    embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        embeddings.token_type_embeddings.weight
    )  # just zero them out b/c RoBERTa doesn't use them.
    embeddings.LayerNorm.weight = sent_encoder.emb_layer_norm.weight
    embeddings.LayerNorm.bias = sent_encoder.emb_layer_norm.bias
    if is_sparse:
        if rank_mask is not None:
            embeddings.rank_mask.data = rank_mask.squeeze(1)
        embeddings.projection.weight = sent_encoder.embed_tokens.projection.weight

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.roberta.encoder.layer[i]
        roberta_layer: TransformerSentenceEncoderLayer = sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        fairseq_attn = roberta_layer.self_attn
        assert (
            fairseq_attn.k_proj.weight.data.shape
            == fairseq_attn.q_proj.weight.data.shape
            == fairseq_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )
        self_attn.query.weight.data = fairseq_attn.q_proj.weight
        self_attn.query.bias.data = fairseq_attn.q_proj.bias
        self_attn.key.weight.data = fairseq_attn.k_proj.weight
        self_attn.key.bias.data = fairseq_attn.k_proj.bias
        self_attn.value.weight.data = fairseq_attn.v_proj.weight
        self_attn.value.bias.data = fairseq_attn.v_proj.bias
        if is_sparse and head_masks is not None:
            # per-layer slice of the per-language head masks
            self_attn.head_mask.data = head_masks[:, i, :]

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == fairseq_attn.out_proj.weight.shape
        self_output.dense.weight = fairseq_attn.out_proj.weight
        self_output.dense.bias = fairseq_attn.out_proj.bias
        self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
        intermediate.dense.weight = roberta_layer.fc1.weight
        intermediate.dense.bias = roberta_layer.fc1.bias
        if is_sparse and hidden_masks is not None:
            # per-layer slice of the per-language FFN hidden masks
            layer.hidden_mask.data = hidden_masks[:, i, :]

        # output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
        bert_output.dense.weight = roberta_layer.fc2.weight
        bert_output.dense.bias = roberta_layer.fc2.bias
        bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
        # end of layer

    # LM Head
    model.lm_head.dense.weight = fairseq_lm_head.dense.weight
    model.lm_head.dense.bias = fairseq_lm_head.dense.bias
    model.lm_head.layer_norm.weight = fairseq_lm_head.layer_norm.weight
    model.lm_head.layer_norm.bias = fairseq_lm_head.layer_norm.bias
    model.lm_head.decoder.weight = fairseq_lm_head.weight
    model.lm_head.bias = fairseq_lm_head.bias
    if is_sparse:
        # note the transpose here
        model.lm_head.projection.weight.data = fairseq_lm_head.proj_weight.T
        if rank_mask is not None:
            model.lm_head.rank_mask.data = rank_mask.squeeze(1)


def _check_outputs_match(model, roberta, en_id, sparsity):
    """Run SAMPLE_TEXT through both models as language ``en_id`` and compare outputs.

    Raises:
        RuntimeError: if the outputs are not ``torch.allclose`` with ``atol=1e-3``.
    """
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        huggingface_output = model(
            input_ids, langs=torch.full_like(input_ids, en_id, dtype=torch.long)
        )[0]
        # NOTE(review): `sparsity` has one entry per language while lang_id has a
        # single entry; kept as in the original call — confirm the fairseq API intends this.
        fairseq_output = roberta.model(
            input_ids, lang_id=torch.tensor([en_id]), target_sparsity=sparsity
        )[0]
    max_absolute_diff = torch.max(torch.abs(huggingface_output - fairseq_output)).item()
    print(f"| max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(huggingface_output, fairseq_output, atol=1e-3)
    print("| Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise RuntimeError("Something went wRoNg")


def main():
    """Load the sparse fairseq checkpoint, convert it, verify the outputs and save it."""
    args = _parse_args()
    print(args)

    # load pretrained model
    print("| Load {}".format(os.path.join(args.input, args.checkpoint)))
    spm_path = _resolve_spm_path(args.spm, args.input, args.output)
    roberta = SparseXLMRModel.from_pretrained(
        args.input,
        checkpoint_file=args.checkpoint,
        data_name_or_path=args.data,
        overrides={'sentencepiece_model': spm_path},
    )
    roberta.eval()

    # Language ids follow the order of the comma-separated `monolingual_langs` arg.
    langs = [lang.strip() for lang in roberta.args.monolingual_langs.split(",")]
    lang2id = {lang: lang_idx for lang_idx, lang in enumerate(langs)}
    if "zh-Hans" in lang2id:
        lang2id["zh"] = lang2id["zh-Hans"]  # for xtreme compatibility
    # TODO: handle 'yo' (Niger-Congo›Atlantic-Congo›Volta-Congo›Benue-Congo›Defoid›Yoruboid›Edekiri)

    # compute masks for all languages
    lang_id = torch.arange(len(langs))  # all possible languages
    # Explicit float dtype: an int64 tensor times a Python float truncates on
    # PyTorch releases without type promotion.
    sparsity = torch.full((len(langs),), args.sparsity, dtype=torch.float)
    rank_mask, head_masks, hidden_masks, _ = roberta.model.encoder.compute_language_masks(
        lang_id, target_sparsity=sparsity
    )
    _report_mask_density(langs, rank_mask, head_masks, hidden_masks)

    # convert fairseq model to huggingface model
    is_sparse = isinstance(roberta.model, SparseXLMRModel)
    roberta_sent_encoder = roberta.model.encoder.sentence_encoder
    config = XLMRobertaConfig(
        vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=roberta.args.encoder_embed_dim,
        num_hidden_layers=roberta.args.encoder_layers,
        num_attention_heads=roberta.args.encoder_attention_heads,
        intermediate_size=roberta.args.encoder_ffn_embed_dim,
        # Match the checkpoint's position table (514 for XLM-R) so the saved
        # config agrees with the position embedding weight copied below.
        max_position_embeddings=roberta_sent_encoder.embed_positions.weight.size(0),
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
        sparse=is_sparse,
        n_langs=len(langs),
        lang2id=lang2id,
    )
    print("| Huggingface XLMRoberta config:", config)

    model = XLMRobertaForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    _copy_weights(model, roberta, config, rank_mask, head_masks, hidden_masks, is_sparse)

    # Let's check that we get the same results.
    _check_outputs_match(model, roberta, langs.index("en"), sparsity)

    pathlib.Path(args.output).mkdir(parents=True, exist_ok=True)
    print(f"| Saving model to {args.output}")
    model.save_pretrained(args.output)


if __name__ == "__main__":
    main()
