# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging
import math
import collections
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import LayerNorm, TransformerSentenceEncoder

from .model_xlmr import XLMRModel
from .hub_interface import RobertaHubInterface


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Registries populated through ``register_to``. Each maps an implementation
# name to a callable.
# Initializers of the sparsity parameters, looked up by ``args.sparse_impl``;
# an unregistered name falls back to a no-op taking ``(model, **kwargs)``.
_INIT_IMPL = collections.defaultdict(lambda: lambda model, **kwargs: None)
# Pre-processing of the mask logits, looked up by ``args.pre_impl``.
_PRE_IMPL = {}
# Logits -> (sparsity, mask) implementations, looked up by ``args.sparse_impl``.
_SPARSE_IMPL = {}
# Post-processing of the mask, looked up by ``args.post_impl``.
_POST_IMPL = {}


def register_to(name: str, mapping: dict):
    """Build a decorator that stores the decorated callable in ``mapping``.

    Args:
        name: key under which the callable is stored.
        mapping: registry that receives the ``name -> callable`` entry; an
            existing entry with the same key is overwritten.

    Returns:
        A decorator that records its argument and hands it back unchanged.
    """

    def _record(impl):
        # Register first, then return the very same object so the decorated
        # name keeps pointing at the original callable.
        mapping[name] = impl
        return impl

    return _record


def inverse_sigmoid(x):
    """Return ``log(x / (1 - x))``, the inverse of the sigmoid function.

    Args:
        x: a ``torch.Tensor``, or any value accepted by ``torch.tensor``.

    Returns:
        A tensor holding the element-wise logit of ``x``.
    """
    value = x if isinstance(x, torch.Tensor) else torch.tensor(x)
    odds = value / (1 - value)
    return torch.log(odds)


@register_model("sparse_xlmr")
class SparseXLMRModel(XLMRModel):
    """XLM-R model built around a ``SparseEncoder``.

    On construction the sparsity parameters of the encoder are initialized:
    through the initializer registered for ``args.sparse_impl`` in the dynamic
    mode, or to all ones in the static (``--non-parameterize``) mode.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        overrides=None,
        **kwargs
    ):
        """Load a checkpoint through ``fairseq.hub_utils`` and wrap it in a hub interface.

        Args:
            model_name_or_path: forwarded to ``hub_utils.from_pretrained``.
            checkpoint_file: forwarded to ``hub_utils.from_pretrained``.
            data_name_or_path: forwarded to ``hub_utils.from_pretrained``.
            bpe: forwarded to ``hub_utils.from_pretrained``.
            overrides (dict, optional): attribute name -> value pairs set on
                the loaded args after the checkpoint has been loaded and after
                ``upgrade_args`` ran. ``None`` (the default) means no overrides.
            **kwargs: forwarded to ``hub_utils.from_pretrained``.

        Returns:
            RobertaHubInterface: wraps the loaded args, the task and the first
            loaded model.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        cls.upgrade_args(x["args"])
        # ``None`` instead of a shared mutable ``{}`` default.
        for k, v in (overrides or {}).items():
            setattr(x["args"], k, v)

        logger.info(x["args"])
        return RobertaHubInterface(x["args"], x["task"], x["models"][0])

    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        if not args.non_parameterize:  # dynamic sparsification
            # NOTE(security): ``init_args`` comes from the command line and is
            # passed to ``eval``; never feed it untrusted input. A literal-only
            # parser (``ast.literal_eval``) would be safer if only dict
            # literals need to be supported.
            _INIT_IMPL[args.sparse_impl](self, **eval(self.args.init_args))
        else:  # static sparsification
            # every component starts unmasked
            nn.init.constant_(self.encoder.rank_mask, 1.)
            nn.init.constant_(self.encoder.head_masks, 1.)
            nn.init.constant_(self.encoder.hidden_masks, 1.)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        XLMRModel.add_args(parser)
        parser.add_argument(
            "--embed-factorize",
            action="store_true",
            default=False,
            help="if set, then factorize the embeddings (and the output projection)",
        )
        parser.add_argument(
            "--non-parameterize",
            action="store_true",
            default=False,
            help="if set, then the sparsity mask will simply be a 0-1 matrix",
        )
        parser.add_argument(
            "--clamp",
            action="store_true",
            default=False,
            help="if set, then the non-factorized sparsity mask will be clamped",
        )
        parser.add_argument(
            "--lang-agnostic",
            action="store_true",
            default=False,
            help="if set, then the sparsity mask will be shared among languages",
        )
        parser.add_argument(
            "--init-args",
            type=str,
            metavar="DICT",
            default="{}",
            help="the arguments of initialization for sparsity control matrix, e.g., \"{'sparsity': 0.5}\".",
        )
        # The ``choices`` below are read from the registries when this method
        # runs, i.e. after the implementations have been registered.
        parser.add_argument(
            "--pre-impl",
            default="no_op",
            choices=_PRE_IMPL.keys(),
            help="the implementation of pre-processing of sparsity, e.g., no_op, etc.",
        )
        parser.add_argument(
            "--pre-args",
            type=str,
            metavar="DICT",
            default="{}",
            help="the arguments of pre-processing.",
        )
        parser.add_argument(
            "--sparse-impl",
            default="hard_concrete",
            choices=_SPARSE_IMPL.keys(),
            help="the implementation of sparsity, e.g., hard_concrete, etc.",
        )
        parser.add_argument(
            "--sparse-args",
            type=str,
            metavar="DICT",
            default="{}",
            help="the arguments of sparsity, e.g., \"{'temperature': 1.}\".",
        )
        parser.add_argument(
            "--post-impl",
            default="no_op",
            choices=_POST_IMPL.keys(),
            help="the implementation of post-processing of sparsity, e.g., no_op, etc.",
        )
        parser.add_argument(
            "--post-args",
            type=str,
            metavar="DICT",
            default="{}",
            help="the arguments of post-processing, e.g., \"{'constant': 1.}\".",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = SparseEncoder(args, task.source_dictionary)
        return cls(args, encoder)


class SparseLMHead(nn.Module):
    """Masked-LM output head with an optional factorized projection and rank mask.

    The head applies dense -> activation -> layer norm, then (when
    ``embed_factorize`` is set) a projection with the transposed
    ``proj_weight``, an optional element-wise ``mask``, and finally the
    vocabulary projection plus bias.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None, embed_factorize=False, proj_weight=None):
        """
        Args:
            embed_dim (int): feature size of the encoder output.
            output_dim (int): size of the output (vocabulary) dimension.
            activation_fn (str): name resolved by ``utils.get_activation_fn``.
            weight (Tensor, optional): output projection weight to reuse; a
                fresh ``(output_dim, embed_dim)`` weight is created if None.
            embed_factorize (bool): whether to add the extra projection.
            proj_weight (Tensor, optional): projection weight to reuse when
                ``embed_factorize`` is set; a fresh one is created if None.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)
        self.embed_factorize = embed_factorize

        # Reuse the given output matrix, otherwise create an untied one.
        self.weight = (
            weight
            if weight is not None
            else nn.Linear(embed_dim, output_dim, bias=False).weight
        )
        self.bias = nn.Parameter(torch.zeros(output_dim))

        if not self.embed_factorize:
            return

        # The projection is transposed in ``forward`` so that it can share the
        # embedding projection; in_features and out_features are therefore
        # swapped here, which matters for initialization when they differ.
        self.proj_weight = (
            proj_weight
            if proj_weight is not None
            else nn.Linear(embed_dim, embed_dim, bias=False).weight
        )

    def forward(self, features, masked_tokens=None, mask=None, **kwargs):
        """
        Args:
            features (Tensor): encoder output.
            masked_tokens (Tensor, optional): index used to select the
                positions to project; all positions are projected if None.
            mask (Tensor, optional): multiplied element-wise into the hidden
                state right before the vocabulary projection.

        Returns:
            Tensor: the vocabulary logits.
        """
        # While training only the masked positions are projected, which saves
        # both memory and computation.
        select_positions = masked_tokens is not None
        if select_positions and mask is not None:
            seq_len = features.size(1)
            mask = mask.expand(-1, seq_len, -1)[masked_tokens, :]
        if select_positions:
            features = features[masked_tokens, :]

        hidden = self.layer_norm(self.activation_fn(self.dense(features)))
        if self.embed_factorize:
            # transposed to match the shared embedding projection
            hidden = F.linear(hidden, self.proj_weight.T)
        if mask is not None:
            hidden = hidden * mask

        # back to the vocabulary size, with bias
        logits = F.linear(hidden, self.weight)
        return logits + self.bias


class SparseEncoder(FairseqEncoder):
    """RoBERTa encoder with masks over embedding ranks, attention heads and FFN neurons.

    Two modes are supported:

    * dynamic sparsification (default): mask logits are computed from learned
      embeddings (``*_weight`` plus ``target_sparsity * *_target``) and turned
      into masks by the pre/sparse/post implementations selected in ``args``;
    * static sparsification (``args.non_parameterize``): the masks are stored
      directly as parameters (``rank_mask``, ``head_masks``, ``hidden_masks``).

    Unless ``args.lang_agnostic`` is set, one row of parameters is kept per
    language listed in ``args.monolingual_langs``.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args
        self.embed_factorize = args.embed_factorize

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=True,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
            q_noise=args.quant_noise_pq,
            qn_block_size=args.quant_noise_pq_block_size,
            embed_factorize=self.embed_factorize,
        )
        args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False)

        # Unless untied, the LM head shares the embedding matrix (and, when
        # factorized, the embedding projection) with the sentence encoder.
        self.lm_head = SparseLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=(
                self.sentence_encoder.embed_tokens.weight
                if not args.untie_weights_roberta
                else None
            ),
            embed_factorize=self.embed_factorize,
            proj_weight=(
                self.sentence_encoder.embed_tokens.projection.weight
                if not args.untie_weights_roberta and self.embed_factorize
                else None
            ),
        )

        # perform conditional sparsity for different languages
        langs = [lang.strip() for lang in args.monolingual_langs.split(",")]
        self.num_langs = len(langs)
        # a single shared row of parameters when the masks are language agnostic
        num_langs = self.num_langs if not self.args.lang_agnostic else 1
        if not self.args.non_parameterize:  # dynamic sparsification
            # language + position factors
            self.rank_weight = nn.Embedding(num_langs, args.encoder_embed_dim)
            self.head_weight = nn.Embedding(num_langs, args.encoder_layers * args.encoder_attention_heads)
            self.hidden_weight = nn.Embedding(num_langs, args.encoder_layers * args.encoder_ffn_embed_dim)
            # language + position + target factors
            self.rank_target = nn.Embedding(num_langs, args.encoder_embed_dim)
            self.head_target = nn.Embedding(num_langs, args.encoder_layers * args.encoder_attention_heads)
            self.hidden_target = nn.Embedding(num_langs, args.encoder_layers * args.encoder_ffn_embed_dim)

            # with --clamp the sparsity parameters are frozen
            self.rank_weight.weight.requires_grad = not args.clamp
            self.head_weight.weight.requires_grad = not args.clamp
            self.hidden_weight.weight.requires_grad = not args.clamp
            self.rank_target.weight.requires_grad = not args.clamp
            self.head_target.weight.requires_grad = not args.clamp
            self.hidden_target.weight.requires_grad = not args.clamp
        else:  # static sparsification (also the utility for computing 1st order based importance)
            self.rank_mask = nn.Parameter(torch.ones(num_langs, args.encoder_embed_dim), requires_grad=not args.clamp)
            self.head_masks = nn.Parameter(torch.ones(num_langs, args.encoder_layers * args.encoder_attention_heads), requires_grad=not args.clamp)
            self.hidden_masks = nn.Parameter(torch.ones(num_langs, args.encoder_layers * args.encoder_ffn_embed_dim), requires_grad=not args.clamp)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): forwarded to the LM head, which then
                only projects the selected positions.
            **unused: ``lang_id`` and ``target_sparsity`` are read from here
                and forwarded to :meth:`extract_features`.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states (time-major, as returned by the
                  sentence encoder), plus the masks and sparsity values
                  produced by :meth:`compute_language_masks`.
        """
        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens,
            lang_id=unused.get("lang_id", None),
            target_sparsity=unused.get("target_sparsity", None),
        )
        if not features_only:
            x = self.output_layer(
                x, masked_tokens=masked_tokens,
                mask=extra.get("rank_mask", None),
            )
        return x, extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
        """Run the masked sentence encoder and return batch-major features.

        ``lang_id``, ``target_sparsity`` and ``token_embeddings`` are read
        from ``kwargs``.

        Returns:
            tuple: the last hidden state transposed to `B x T x C`, and a dict
            holding 'inner_states' (None unless ``return_all_hiddens``)
            together with everything returned in ``extra`` by
            :meth:`compute_language_masks`.
        """
        # compute the language-dependent architectural masks
        rank_mask, head_masks, hidden_masks, extra = self.compute_language_masks(
            kwargs.get("lang_id", None), target_sparsity=kwargs.get("target_sparsity", None)
        )
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
            token_embeddings=kwargs.get("token_embeddings", None),
            rank_mask=rank_mask,
            head_masks=head_masks,
            hidden_masks=hidden_masks,
        )
        features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
        return features, {"inner_states": inner_states if return_all_hiddens else None, **extra}

    def _per_example(self, table, lang_id):
        """Return one row of ``table`` per batch element, of shape `(batch, n)`.

        ``table`` is an ``nn.Embedding`` (dynamic mode) or a raw parameter
        (static mode). In the language-agnostic setting the single shared row
        is broadcast over the batch and ``lang_id`` only provides the batch
        size; otherwise rows are looked up by ``lang_id``.
        """
        is_module = isinstance(table, nn.Embedding)
        if self.args.lang_agnostic:
            weight = table.weight if is_module else table
            return weight.expand(lang_id.size(0), -1)
        return table(lang_id) if is_module else F.embedding(lang_id, table)

    def _mask_logits(self, weight, target, target_sparsity, lang_id, shape=None):
        """Compute ``weight + target_sparsity * target`` for every batch element.

        Args:
            weight, target (nn.Embedding): the two factors of the logits.
            target_sparsity (FloatTensor): already shaped `(batch, 1, 1)`.
            lang_id (LongTensor): language id of shape `(batch,)`.
            shape (tuple, optional): target shape `(-1, layer, n)`. If None
                the result is `(batch, 1, n)`, i.e. the layer dimension is
                replaced by 1 (used for the rank mask, which broadcasts over
                ``src_len``).
        """
        base = self._per_example(weight, lang_id)
        slope = self._per_example(target, lang_id)
        if shape is None:
            base, slope = base.unsqueeze(1), slope.unsqueeze(1)
        else:
            base, slope = base.view(*shape), slope.view(*shape)
        return base + target_sparsity * slope

    def compute_language_masks(self, lang_id, target_sparsity=None):
        """
        Args:
            lang_id (LongTensor): language id of shape `(batch,)`. Always
                required; in the language-agnostic setting only its batch
                size is used.
            target_sparsity (FloatTensor): target sparsity of shape `(batch,)`.
                Required for dynamic sparsification, ignored otherwise.

        Returns:
            tuple:
                - FloatTensor: a tensor with real values, used for masking ranks,
                  of shape `(batch, 1, num_rank)`; None if the embeddings are
                  not factorized.
                - FloatTensor: a tensor with real values, used for masking heads,
                  of shape `(batch, layer, num_head)`.
                - FloatTensor: a tensor with real values, used for masking neurons,
                  of shape `(batch, layer, num_hidden)`.
                - Dict: a dictionary that contains the sparsity for each mask
                  (``*_sparsity``) and the masks themselves (``rank_mask``,
                  ``head_masks``, ``hidden_masks``).

        Raises:
            ValueError: if ``lang_id`` is None, or if ``target_sparsity`` is
                None while dynamic sparsification is used.
        """
        if lang_id is None:
            raise ValueError("lang_id is required to compute the sparsity masks")
        dynamic = not self.args.non_parameterize
        if dynamic and target_sparsity is None:
            raise ValueError("target_sparsity is required for dynamic sparsification")

        extra = {}  # pack sparsity for loss calculation
        num_layers = self.args.encoder_layers
        head_shape = (-1, num_layers, self.args.encoder_attention_heads)
        hidden_shape = (-1, num_layers, self.args.encoder_ffn_embed_dim)
        rank_mask = None

        if dynamic:  # L0 (per language, or standard when language agnostic)
            ts = target_sparsity.unsqueeze(-1).unsqueeze(-1).type_as(self.head_target.weight)
            # NOTE: the order head -> hidden -> rank is kept on purpose, the
            # sparse implementation may draw random numbers on every call.
            extra["head_sparsity"], head_masks = self._sparsity_mask(
                self._mask_logits(self.head_weight, self.head_target, ts, lang_id, head_shape)
            )
            extra["hidden_sparsity"], hidden_masks = self._sparsity_mask(
                self._mask_logits(self.hidden_weight, self.hidden_target, ts, lang_id, hidden_shape)
            )
            if self.embed_factorize:
                extra["rank_sparsity"], rank_mask = self._sparsity_mask(
                    self._mask_logits(self.rank_weight, self.rank_target, ts, lang_id)
                )
                # TODO: a separate rank_mask if embeddings and lm_head are not shared
        else:  # gradient-based pruning: the stored masks are used as they are
            head_masks = self._per_example(self.head_masks, lang_id).view(*head_shape)
            hidden_masks = self._per_example(self.hidden_masks, lang_id).view(*hidden_shape)
            extra["head_sparsity"] = head_masks
            extra["hidden_sparsity"] = hidden_masks
            if self.embed_factorize:
                rank_mask = self._per_example(self.rank_mask, lang_id).unsqueeze(1)
                extra["rank_sparsity"] = rank_mask

        extra["rank_mask"] = rank_mask  # must
        extra["head_masks"] = head_masks
        extra["hidden_masks"] = hidden_masks
        return rank_mask, head_masks, hidden_masks, extra

    def _sparsity_mask(self, x):
        """
        Args:
            x (FloatTensor): parameters of shape `(batch, layer, num_head/_hidden)`.

        Returns:
            tuple:
                - FloatTensor: a tensor with values ranging from 0 to 1, of shape
                  `(batch, layer, num_head/_hidden)`, used to calculate sparsity
                - FloatTensor: a tensor with real values, used for actual masking,
                  of shape `(batch, layer, num_head/_hidden)`.
        """
        # NOTE(security): the ``*_args`` strings come from the command line and
        # are passed to ``eval`` on every call; never feed them untrusted
        # input. They are re-read each time, so changes made to ``self.args``
        # after construction take effect.
        pre_mask = _PRE_IMPL[self.args.pre_impl](self, x, **eval(self.args.pre_args))
        sparsity, mask = _SPARSE_IMPL[self.args.sparse_impl](self, pre_mask, **eval(self.args.sparse_args))
        post_mask = _POST_IMPL[self.args.post_impl](self, mask, **eval(self.args.post_args))
        return sparsity, post_mask

    # The functions below are stored in the registries as plain functions and
    # are invoked by ``_sparsity_mask`` with ``self`` passed explicitly.

    @register_to("no_op", _PRE_IMPL)
    def pre_no_op(self, x, **kwargs):
        """Pre-processing that returns the logits unchanged."""
        return x

    @register_to("local_baseline", _PRE_IMPL)
    def local_baseline(self, x, **kwargs):
        """Subtract the mean over the last dimension from the logits."""
        return x - torch.mean(x, dim=-1, keepdim=True)

    @register_to("global_baseline", _PRE_IMPL)
    def global_baseline(self, x, **kwargs):
        """Subtract the mean over dimensions 1 and 2 from the logits."""
        return x - torch.mean(x, dim=[1, 2], keepdim=True)

    @register_to("sigmoid", _SPARSE_IMPL)
    def sigmoid(self, x, **kwargs):
        """Use ``sigmoid(x)`` both as the sparsity value and as the mask."""
        y = torch.sigmoid(x)
        return y, y

    @register_to("hard_tanh", _SPARSE_IMPL)
    def hard_tanh(self, x, **kwargs):
        """Use ``relu(tanh(x))`` both as the sparsity value and as the mask."""
        y = torch.relu(torch.tanh(x))
        return y, y

    # "Learning Sparse Neural Networks through L0 Regularization" (Louizos et al., 2018)
    @register_to("hard_concrete", _SPARSE_IMPL)
    def concrete(self, x, temperature=2./3., zeta=1.1, gamma=-0.1, **kwargs):
        """Hard-concrete gate.

        Returns:
            tuple: the L0 term ``sigmoid(x - temperature * log(-gamma / zeta))``
            and the gate ``clamp(sigmoid(x') * (zeta - gamma) + gamma, 0, 1)``.
            While training (and unless ``args.clamp`` is set) ``x'`` is ``x``
            perturbed by logistic noise and divided by the temperature;
            otherwise ``x' = x``.
        """
        sparsity = torch.sigmoid(x - temperature * math.log(-gamma / zeta))  # L0
        if self.training and not self.args.clamp:
            u = torch.rand(x.size(), dtype=x.dtype, device=x.device)
            x = (torch.log(u) - torch.log(1. - u) + x) / temperature
        return sparsity, torch.clamp(torch.sigmoid(x) * (zeta - gamma) + gamma, min=0, max=1)

    @register_to("no_op", _POST_IMPL)
    def post_no_op(self, x, **kwargs):
        """Post-processing that returns the mask unchanged."""
        return x

    @register_to("auto", _POST_IMPL)
    def auto(self, x, truncation=1e3, **kwargs):
        """Rescale the mask by ``size / sum`` over the last dimension."""
        # scale up the mask by the inverse of sparsity, 1e-4 is to avoid dividing 0.
        # But it has a high risk of overflowing in FP16 when the sparsity is high,
        # so we truncate the automatic determined scale.
        scale = x.size(-1) / (torch.sum(x, dim=-1, keepdim=True) + 1e-4)
        if truncation is not None:
            scale = torch.clamp(scale, max=truncation)
        return x * scale

    @register_to("fix", _POST_IMPL)
    def fix(self, x, constant=2., **kwargs):
        """Rescale the mask by a fixed ``constant``."""
        return x * constant

    def output_layer(self, features, masked_tokens=None, **unused):
        """Apply the LM head; ``mask`` is read from ``unused`` if given."""
        return self.lm_head(features, masked_tokens, mask=unused.get("mask", None))

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    @staticmethod
    def _svd_factorize(w):
        """Split ``w`` into ``(u_, v_)`` such that ``u_ @ v_.T`` reconstructs ``w``.

        With the thin SVD ``w = U diag(s) V^T`` this returns
        ``U diag(sqrt(s))`` and ``V diag(sqrt(s))``, both cast back to the
        dtype/device of ``w``.
        """
        # NumPy needs a detached CPU tensor; a no-op for plain CPU checkpoints.
        u, s, v_t = np.linalg.svd(w.detach().cpu().float().numpy(), full_matrices=False)
        sqrt_s = torch.sqrt(torch.from_numpy(s))
        # Scaling the columns is the same as multiplying by diag(sqrt(s)),
        # without building and multiplying a dense diagonal matrix.
        u_ = (torch.from_numpy(u) * sqrt_s).type_as(w)
        v_ = (torch.from_numpy(v_t.T) * sqrt_s).type_as(w)
        return u_, v_

    def upgrade_state_dict_named(self, state_dict, name):
        """Make a checkpoint loadable by this encoder (modifies ``state_dict`` in place).

        * Sparsity parameters missing from the checkpoint are filled in with
          the current values of this module.
        * When the embeddings are factorized and the checkpoint has no
          projection weights, the embedding matrix and the LM head matrix are
          decomposed via SVD into a factor and a projection.

        Args:
            state_dict (dict): the checkpoint state.
            name (str): name of this module inside the checkpoint; used as
                key prefix (followed by a dot) unless empty.
        """
        prefix = name + "." if name != "" else ""

        # Add architectural embeddings if not presented.
        cur_state = self.state_dict()
        if not self.args.non_parameterize:
            own_keys = [
                embed_name + ".weight"
                for embed_name in (
                    "rank_weight", "head_weight", "hidden_weight",
                    "rank_target", "head_target", "hidden_target",
                )
            ]
        else:
            own_keys = ["rank_mask", "head_masks", "hidden_masks"]
        for key in own_keys:
            if prefix + key not in state_dict:
                logger.info("Adding " + prefix + key)
                state_dict[prefix + key] = cur_state[key]

        if self.embed_factorize:
            # (what is decomposed, key of the full matrix, key of the projection)
            factorized = [
                (
                    "embeddings",
                    prefix + "sentence_encoder.embed_tokens.weight",
                    prefix + "sentence_encoder.embed_tokens.projection.weight",
                ),
                (
                    "lm_head",
                    prefix + "lm_head.weight",
                    prefix + "lm_head.proj_weight",
                ),
            ]
            for what, weight_key, proj_key in factorized:
                if proj_key in state_dict:
                    continue  # already factorized
                logger.info("Decomposing " + what + " via SVD")
                u_, v_ = self._svd_factorize(state_dict[weight_key])
                logger.info("Overwriting " + weight_key)
                state_dict[weight_key] = u_
                logger.info("Adding " + proj_key)
                state_dict[proj_key] = v_


@register_to("hard_concrete", _INIT_IMPL)
def hard_concrete_init(model, **kwargs):
    """
    Initialize the embeddings that control the hard-concrete sparsity masks.

    There are two ways to initialize the weights that controls the sparsity,
    and they are exclusive to each other. The default is random initialization,
    and its results will be overwritten if the arguments of selective
    initialization are specified.

    1. Random initialization (follows a normal distribution).
    `mean`: optional (default 0), the mean of randomly initialized weights.
    `std`: optional (default 0.02), the standard deviation of randomly
    initialized weights.
    The `*_target` embeddings are always reset to 0.

    2. Selective initialization: based on the importance scores of the masked
    components, prune those unimportant ones with a given sparsity.
    `score_file`: optional (default None), the importance score file path. It
    must hold a `(rank, head, hidden)` triple of importance scores, where the
    rank scores may be None. If the path does not exist, a warning is logged
    and the random initialization is kept.
    `sparsity`: optional, how much parameters will be retained
    initially. If not specified, then initializing parameters in a way that
    allows the model adapts to any given sparsity, otherwise just adapts the
    parameters to a given sparsity.
    `step`: optional (default 0.01), how small change in sparsity the model
    should pay attention to, e.g., 0.01 means the model should change if the
    change of the given sparsity is larger than or equal to 1%. This argument
    is crucial in FP16 (it is only read when `model.args.fp16` is set).
    """
    encoder = model.encoder
    mean = kwargs.get("mean", 0.)
    std = kwargs.get("std", 0.02)  # simply follow BERT's initialization
    nn.init.normal_(encoder.rank_weight.weight, mean=mean, std=std)
    nn.init.normal_(encoder.head_weight.weight, mean=mean, std=std)
    nn.init.normal_(encoder.hidden_weight.weight, mean=mean, std=std)
    nn.init.constant_(encoder.rank_target.weight, 0.)
    nn.init.constant_(encoder.head_target.weight, 0.)
    nn.init.constant_(encoder.hidden_target.weight, 0.)

    score_file = kwargs.get("score_file", None)  # initialize the mask based on the given importance scores
    if score_file is None:
        return
    if not os.path.exists(score_file):
        # Do not fail silently: the caller asked for a selective initialization.
        logger.warning(
            "score file {} does not exist, keeping the random initialization".format(score_file)
        )
        return

    # stretch parameters of the hard-concrete gate (same values as the
    # defaults of the ``hard_concrete`` sparse implementation)
    zeta, gamma = 1.1, -0.1

    def inverse_concrete(x):
        """Logit that makes the deterministic gate output ``x`` (clamped to [0, 1])."""
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = torch.clamp(x, min=0, max=1)
        return inverse_sigmoid((x - gamma) / (zeta - gamma))

    logger.info("initialize mask weights based on scores from {}".format(score_file))
    # NOTE(security): ``torch.load`` unpickles the file; only load score files
    # from a trusted source.
    rank_importance, head_importance, hidden_importance = torch.load(
        score_file,
        map_location="cpu",
    )  # num_lang x num_head/neuron
    has_rank = rank_importance is not None

    if model.args.lang_agnostic:
        # a single shared mask: accumulate the scores over languages
        if has_rank:
            rank_importance = rank_importance.sum(dim=0, keepdim=True)
        head_importance = head_importance.sum(dim=0, keepdim=True)
        hidden_importance = hidden_importance.sum(dim=0, keepdim=True)

    # compute the mask given the desired sparsity and importance scores
    # Relative weight of one head / one hidden neuron with respect to one rank
    # (presumably their parameter counts for a head size of 64 -- TODO confirm).
    head_cost, hidden_cost = 64 * 4, 2
    head_weight = torch.ones_like(head_importance) * head_cost
    hidden_weight = torch.ones_like(hidden_importance) * hidden_cost
    if has_rank:
        # only built when rank scores exist (``ones_like(None)`` would raise)
        rank_weight = torch.ones_like(rank_importance)
        _flatten = torch.cat([rank_importance, head_importance, hidden_importance], dim=-1)
        _weighted_flatten = torch.cat(
            [rank_importance, head_importance.repeat(1, head_cost), hidden_importance.repeat(1, hidden_cost)], dim=-1)
        _weight = torch.cat([rank_weight, head_weight, hidden_weight], dim=-1)
        num_rank = rank_importance.size(1)
    else:
        _flatten = torch.cat([head_importance, hidden_importance], dim=-1)
        _weighted_flatten = torch.cat(
            [head_importance.repeat(1, head_cost), hidden_importance.repeat(1, hidden_cost)], dim=-1)
        _weight = torch.cat([head_weight, hidden_weight], dim=-1)
        num_rank = 0
    num_head = head_importance.size(1)

    if 'sparsity' in kwargs:
        sparsity = kwargs.get("sparsity", 1.)  # initialize the mask with a given sparsity
        # keep the components whose score exceeds the (1 - sparsity) quantile
        threshold = torch.from_numpy(np.quantile(_weighted_flatten.numpy(), 1 - sparsity, axis=-1, keepdims=True))
        mask = (_flatten > threshold).float()

        if has_rank:
            rank_mask = mask[:, :num_rank].view(rank_importance.size())
            mask = mask[:, num_rank:]
        else:
            rank_mask = None
        head_masks = mask[:, :num_head].view(head_importance.size())
        hidden_masks = mask[:, num_head:].view(hidden_importance.size())

        # find the proper initialization given a special mask
        if rank_mask is not None:
            encoder.rank_weight.weight.data = inverse_concrete(rank_mask).view(encoder.rank_weight.weight.size())
        encoder.head_weight.weight.data = inverse_concrete(head_masks).view(encoder.head_weight.weight.size())
        encoder.hidden_weight.weight.data = inverse_concrete(hidden_masks).view(encoder.hidden_weight.weight.size())
    else:  # if not specify a static sparsity ratio, then consider it to be dynamic
        _, idx = torch.sort(_flatten, dim=-1, descending=True)
        _, reverse_idx = idx.sort(dim=-1)
        score = _weight.gather(-1, idx)
        rank = score.cumsum(dim=-1)
        if model.args.fp16:
            step = kwargs.get('step', 0.01)
            threshold = torch.ones_like(_flatten)
            lowest = torch.arange(0, 1 + step, step)  # the fraction of the lowest values
            _quantiles = torch.from_numpy(np.quantile(_weighted_flatten.numpy(), lowest.numpy(), axis=-1, keepdims=True))  # the threshold of the fraction
            for i in range(lowest.size(-1)):
                threshold[_flatten > _quantiles[i, :, :]] = 1. - lowest[i]
            delta = torch.ones_like(threshold) * step
        else:  # allow much more fine-grained sparsity, but also require higher numerical precision
            threshold = rank / rank[:, -1:]  # should fire if #retained-params > threshold
            threshold = threshold.gather(-1, reverse_idx)  # reverse sort
            delta = _weight / rank[:, -1:]
        # logits are linear in the target sparsity: ``w + target * w_t``
        w_t = (inverse_concrete(1.) - inverse_concrete(0.)) / delta
        w = (1. - threshold / delta) * inverse_concrete(1.) + threshold / delta * inverse_concrete(0.)

        if has_rank:
            encoder.rank_weight.weight.data = w[:, :num_rank].view(encoder.rank_weight.weight.size())
            encoder.rank_target.weight.data = w_t[:, :num_rank].view(encoder.rank_target.weight.size())
            w = w[:, num_rank:]
            w_t = w_t[:, num_rank:]
        encoder.head_weight.weight.data = w[:, :num_head].view(encoder.head_weight.weight.size())
        encoder.head_target.weight.data = w_t[:, :num_head].view(encoder.head_target.weight.size())
        encoder.hidden_weight.weight.data = w[:, num_head:].view(encoder.hidden_weight.weight.size())
        encoder.hidden_target.weight.data = w_t[:, num_head:].view(encoder.hidden_target.weight.size())


@register_model_architecture("sparse_xlmr", "sparse_xlmr")
def base_architecture(args):
    """Fill in the default hyper-parameters of the ``sparse_xlmr`` architecture.

    Every attribute that is already present on ``args`` is kept; only missing
    ones receive the default listed here.
    """
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    # Read the same attribute that is written; a misspelled key here would
    # reset a user-provided value to False.
    args.spectral_norm_classification_head = getattr(
        args, "spectral_norm_classification_head", False
    )


@register_model_architecture("sparse_xlmr", "sparse_xlmr_base")
def roberta_base_architecture(args):
    """Architecture ``sparse_xlmr_base``: applies exactly the defaults of ``base_architecture``."""
    base_architecture(args)


@register_model_architecture("sparse_xlmr", "sparse_xlmr_large")
def roberta_large_architecture(args):
    """Architecture ``sparse_xlmr_large``: large-size defaults, then the base defaults.

    Values already present on ``args`` take precedence over the defaults.
    """
    large_defaults = (
        ("encoder_layers", 24),
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
    )
    for attr, fallback in large_defaults:
        setattr(args, attr, getattr(args, attr, fallback))
    # everything not set above falls back to the base configuration
    base_architecture(args)
