# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """

import numpy as np

import torch
import torch.nn as nn
import torch.utils.checkpoint

from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import gelu
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
)
from transformers.utils import logging

from transformers.models.roberta.modeling_roberta import (
    RobertaEmbeddings,
    RobertaLayer,
    RobertaPreTrainedModel,
)

import copy

# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)

# Identifiers passed to `add_code_sample_docstrings` on `RobertaModel.forward`
# to build the auto-generated usage example in its docstring.
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

# Names of publicly released RoBERTa checkpoints, kept for parity with the
# upstream transformers module; nothing in the visible code reads this list.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]


class PromptGeneratorWithNaivePromptGenerator(nn.Module):
    """Naive prompt generator ('NPG').

    Takes the hidden state at sequence position 0, sends it through a
    bottleneck MLP (``hidden_size -> proj_down_size -> num_prompt_tokens *
    hidden_size``) and reshapes the output into ``num_prompt_tokens`` prompt
    vectors per example.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.num_prompt_tokens = config.num_prompt_tokens

        width = config.hidden_size
        bottleneck = config.proj_down_size
        # Registration order (proj_down, act, proj_up, dropout) is kept stable so
        # parameter initialisation and state_dict keys do not change.
        self.proj_down = nn.Linear(width, bottleneck)
        self.intermediate_act_fn = nn.ReLU()
        self.proj_up = nn.Linear(bottleneck, self.num_prompt_tokens * width)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states=None, attention_mask=None):
        """Produce prompt embeddings from the first-position representation.

        Args:
            hidden_states: tensor indexed as ``[:, 0, :]``; presumably
                ``(batch, seq_len, hidden_size)`` — TODO confirm at call sites.
            attention_mask: accepted for interface parity with the pooling
                generator; not used here.

        Returns:
            Tensor of shape ``(batch, num_prompt_tokens, hidden_size)`` after
            dropout.
        """
        first_position = hidden_states[:, 0, :]
        batch_size = first_position.shape[0]

        bottleneck = self.intermediate_act_fn(self.proj_down(first_position))
        flat_prompts = self.proj_up(bottleneck)
        prompts = flat_prompts.view(batch_size, self.num_prompt_tokens, -1)
        return self.dropout(prompts)


class PromptGeneratorWithPoolingPromptGenerator(nn.Module):
    """Pooling prompt generator ('APPG' = adaptive average, 'MPPG' = adaptive max).

    Every token is projected down to ``proj_down_size``. The attended tokens of
    each sequence are then adaptively pooled to exactly ``num_prompt_tokens``
    positions, passed through ReLU, projected back up to ``hidden_size`` and
    regularised with dropout.

    Raises:
        ValueError: if ``config.generator_type`` is neither 'MPPG' nor 'APPG'.
    """

    def __init__(self, config):
        super().__init__()

        self.generator_type = config.generator_type
        self.num_prompt_tokens = config.num_prompt_tokens

        self.proj_down = nn.Linear(config.hidden_size, config.proj_down_size)
        self.intermediate_act_fn = nn.ReLU()
        if self.generator_type == 'MPPG':
            self.adaptive_pooling = nn.AdaptiveMaxPool1d(self.num_prompt_tokens)
        elif self.generator_type == 'APPG':
            self.adaptive_pooling = nn.AdaptiveAvgPool1d(self.num_prompt_tokens)
        else:
            # Fail fast: without a pooling layer `forward` cannot run at all.
            raise ValueError(
                "Unsupported generator_type {!r} for pooling prompt generator; "
                "expected 'MPPG' or 'APPG'.".format(self.generator_type)
            )
        self.proj_up = nn.Linear(config.proj_down_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states=None, attention_mask=None):
        """Pool the attended tokens of each sequence into prompt embeddings.

        Args:
            hidden_states: expected shape ``(batch, seq_len, hidden_size)``.
            attention_mask: extended (additive) attention mask, expected shape
                ``(batch, 1, 1, seq_len)``. Positions equal to ``0.0`` are
                treated as real tokens; any other value is treated as padding.

        Returns:
            Tensor of shape ``(batch, num_prompt_tokens, hidden_size)``.
        """
        hidden_states = self.proj_down(hidden_states)

        # Boolean (batch, seq_len) selector of real tokens. Selecting by mask,
        # rather than taking a prefix of length sum(mask), also handles
        # left-padded or otherwise non-contiguous masks correctly.
        keep = attention_mask.squeeze(1).squeeze(1) == 0.0

        batch_prompts = []
        for row, row_keep in zip(hidden_states, keep):
            tokens = row[row_keep]  # (num_kept, D)
            # Adaptive 1D pooling works over the last dim: (1, D, num_kept) -> (1, D, P)
            pooled = self.adaptive_pooling(tokens.t().unsqueeze(0))
            batch_prompts.append(pooled.transpose(1, 2))  # (1, P, D)
        hidden_states = torch.cat(batch_prompts, dim=0)

        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.proj_up(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Adapted from transformers.models.bert.modeling_bert.BertEncoder (Bert->Roberta);
# extended here to splice prompt tokens into the hidden-state sequence.
class RobertaEncoder(nn.Module):
    """Stack of ``RobertaLayer`` modules that can insert prompt tokens.

    Prompt vectors are always prepended along dim 1 of the hidden states, and
    ``prompt_attention_mask`` is prepended along the last dim of the attention
    mask. Where the insertion happens depends on the config:

    * ``add_prompt_layer < 1`` and a ``prompt_embedding`` is given: before the
      first layer.
    * ``config.mode == 'generator_pt'``: right after layer ``add_prompt_layer``
      (1-based), using prompts produced by ``self.prompt_generator`` from the
      current hidden states. ``add_prompt_generator`` must be called first.
    * ``config.mode == 'vanilla_pt'``: right after layer ``add_prompt_layer``,
      using the ``prompt_embedding`` tensor passed to ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.generator_type = config.generator_type
        self.add_prompt_layer = config.add_prompt_layer
        # Stored for reference; not read by any code in this class.
        self.prompt_layer_list = config.prompt_layer_list

        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])

    def add_prompt_generator(self):
        """Attach ``self.prompt_generator`` matching ``self.generator_type``.

        'NPG' builds a ``PromptGeneratorWithNaivePromptGenerator``; 'APPG' and
        'MPPG' build a ``PromptGeneratorWithPoolingPromptGenerator``.

        Raises:
            NotImplementedError: for any other ``generator_type``.
        """
        if self.generator_type == 'NPG':
            self.prompt_generator = PromptGeneratorWithNaivePromptGenerator(self.config)
        elif self.generator_type == 'APPG' or self.generator_type == 'MPPG':
            self.prompt_generator = PromptGeneratorWithPoolingPromptGenerator(self.config)
        else:
            raise NotImplementedError

    def add_multi_prompt_generators(self):
        """Experimental: attach several prompt generators.

        NOTE(review): this looks unfinished. Only the 'NPG' branch has a lasting
        effect (it sets ``prompt_generator_1``..``prompt_generator_3``). The
        'APPG'/'MPPG' branch builds a generator in a local variable that is
        discarded on return. ``forward`` only ever uses
        ``self.prompt_generator``, never the numbered generators.

        Raises:
            NotImplementedError: for an unsupported ``generator_type``.
        """
        if self.generator_type == 'NPG':
            self.prompt_generator_1 = PromptGeneratorWithNaivePromptGenerator(self.config)
            self.prompt_generator_2 = PromptGeneratorWithNaivePromptGenerator(self.config)
            self.prompt_generator_3 = PromptGeneratorWithNaivePromptGenerator(self.config)


        elif self.generator_type == 'APPG' or self.generator_type == 'MPPG':
            # NOTE(review): local only; dropped when the method returns.
            prompt_generator = PromptGeneratorWithPoolingPromptGenerator(self.config)
        else:
            raise NotImplementedError
        # # self.prompt_generator_list = [None for _ in range(self.config.num_hidden_layers)]
        # self.prompt_generator_list = []
        # for idx in self.prompt_layer_list:
        #     # self.prompt_generator_list[idx] = copy.deepcopy(prompt_generator)
        #     self.prompt_generator_list.append(copy.deepcopy(prompt_generator))


    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        prompt_embedding=None,
        prompt_attention_mask=None,
    ):
        """Run every layer, inserting prompt tokens as configured.

        Args:
            hidden_states: embedding output fed to the first layer.
            attention_mask: extended attention mask for ``hidden_states``. It is
                concatenated with ``prompt_attention_mask`` along the last dim
                whenever prompts are inserted.
            head_mask, encoder_hidden_states, encoder_attention_mask,
            past_key_values, use_cache, output_attentions,
            output_hidden_states, return_dict: forwarded to / handled as in
                the upstream BertEncoder.
            prompt_embedding: prompt vectors to prepend. Used before the first
                layer when ``add_prompt_layer < 1``, and required to be a
                ``torch.Tensor`` in 'vanilla_pt' mode.
            prompt_attention_mask: mask prepended alongside any prompt; must not
                be None.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions``, or a tuple of its
            non-None fields when ``return_dict`` is False. If prompts were
            inserted, ``last_hidden_state`` begins with the prompt positions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Required by every prompt-insertion branch below.
        assert prompt_attention_mask is not None
        # if self.generator_type is not None:
        #     assert prompt_embedding is None
        # else:
        #     assert prompt_embedding is not None

        # Prompts at the input: prepend before the first layer runs.
        if self.add_prompt_layer < 1 and (prompt_embedding is not None):
            hidden_states = torch.cat([prompt_embedding, hidden_states], dim=1)
            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=-1)

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                # past_key_value / output_attentions are read from the enclosing
                # scope by the closure rather than passed through checkpoint().
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if self.config.mode == 'generator_pt':
                assert self.generator_type is not None
                # After layer `add_prompt_layer` (1-based): generate prompts from
                # its output and prepend them for the remaining layers.
                if (i + 1) == self.add_prompt_layer:
                    prompt_embedding = self.prompt_generator(hidden_states, attention_mask=attention_mask)
                    hidden_states = torch.cat([prompt_embedding, hidden_states], dim=1)
                    attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=-1)

            elif self.config.mode == 'vanilla_pt':
                # NOTE(review): checked on every layer, so a list of per-layer
                # prompts (as RobertaPromptTuning builds when prompt_layer_list
                # is set) is rejected here.
                assert isinstance(prompt_embedding, torch.Tensor)
                if (i + 1) == self.add_prompt_layer:
                    hidden_states = torch.cat([prompt_embedding, hidden_states], dim=1)
                    attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=-1)



            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


# Class-level documentation fragment; attached to `RobertaModel` through the
# `add_start_docstrings` decorator below. Runtime text — keep wording stable.
ROBERTA_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

# Input-argument documentation fragment. The `{0}` placeholders are filled with
# the input shape via `.format(...)` where it is passed to
# `add_start_docstrings_to_model_forward` on `RobertaModel.forward`.
ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    ROBERTA_START_DOCSTRING,
)
class RobertaModel(RobertaPreTrainedModel):
    """

    RoBERTa embeddings plus a prompt-aware :class:`RobertaEncoder`, without a pooler
    (``pooler_output`` is always ``None``). Prompt insertion is controlled by
    ``config.mode``, ``config.add_prompt_layer`` and ``config.num_prompt_tokens``.

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762

    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Adapted from BertModel.__init__: the pooler is omitted and prompt settings are cached.
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Prompt-related settings mirrored from the config.
        self.generator_type = config.generator_type
        self.add_prompt_layer = config.add_prompt_layer
        self.num_prompt_tokens = config.num_prompt_tokens

        # Embeddings are registered before the encoder, as upstream.
        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prune attention heads. ``heads_to_prune`` maps a layer index to the list of head indices to remove in that
        layer (see base class PreTrainedModel).
        """
        for layer_idx, head_list in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_list)

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        prompt_embedding=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        prompt_embedding (:obj:`torch.FloatTensor`, `optional`):
            Prompt vectors handed to the encoder, which prepends them to the sequence (see :class:`RobertaEncoder`).
            Each prompt position is fully attended (an all-ones mask of length ``num_prompt_tokens`` is built here).
        """
        cfg = self.config

        # Resolve per-call options, falling back to the config.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Caching is only meaningful for decoders.
        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape

        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Prompt positions are always attended: extend an all-ones mask the same
        # way as the input mask so the two can be concatenated later.
        prompt_attention_mask = self.get_extended_attention_mask(
            torch.ones(batch_size, self.num_prompt_tokens).to(attention_mask.device),
            (batch_size, self.num_prompt_tokens),
            device,
        )
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Cross-attention mask is only needed for a decoder fed encoder states.
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Normalise head_mask to one entry per hidden layer (1.0 keeps a head).
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            prompt_embedding=prompt_embedding,
            prompt_attention_mask=prompt_attention_mask,
        )

        sequence_output = encoder_outputs[0]

        # No pooler in this variant: the pooled slot is always None.
        if not return_dict:
            return (sequence_output, None) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=None,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class RobertaLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> LayerNorm -> projection onto the vocabulary.

    The output bias is stored in ``self.bias`` and is the very same Parameter as
    ``self.decoder.bias``, so resizing the token embeddings also resizes the bias.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(width, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Tie the decoder's bias to self.bias (same object, not a copy).
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Map ``features`` (last dim ``hidden_size``) to vocabulary logits; extra kwargs are ignored."""
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed)


class RobertaPromptTuning(RobertaPreTrainedModel):
    """RoBERTa masked-LM model for (deep) prompt tuning.

    Prompt vectors come either from a trainable ``nn.Embedding`` table
    (``config.mode == 'vanilla_pt'``, created by
    :meth:`generate_prompt_embeddings`) or from a prompt generator attached to
    the encoder (``config.mode == 'generator_pt'``). The LM head is applied only
    at the position given by ``mask_pos``.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    # Vocabulary-initialised prompts are sampled from the first this-many word embeddings.
    _PROMPT_INIT_VOCAB_RANGE = 5000

    def __init__(self, config, initialize_from_vocab=True):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForPromptTuning` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.config = config
        self.generator_type = config.generator_type
        self.ft_idx_list = config.ft_idx_list
        self.num_prompt_tokens = config.num_prompt_tokens
        self.add_prompt_layer = config.add_prompt_layer
        self.initialize_from_vocab = initialize_from_vocab

        self.roberta = RobertaModel(config)
        self.lm_head = RobertaLMHead(config)

        self.init_weights()

        # Created later by `generate_prompt_embeddings` (stays None with a generator).
        self.prompt_embedding = None

    def generate_prompt_embeddings(self):
        """Create ``self.prompt_embedding`` according to the config.

        * With a prompt generator: no table is needed; ``prompt_embedding`` is
          set to None (``add_prompt_layer`` must be non-zero).
        * No generator, ``prompt_layer_list`` is None: a table of
          ``num_prompt_tokens`` vectors. If ``initialize_from_vocab`` is set,
          it is initialised from randomly chosen rows among the first 5000
          word embeddings.
        * No generator, ``prompt_layer_list`` set: one block of
          ``num_prompt_tokens`` vectors per listed layer.

        The table is placed on the same device/dtype as the word embeddings,
        so this can be called before or after moving the model.

        Raises:
            NotImplementedError: multi-layer prompts with ``initialize_from_vocab``.
        """
        if self.generator_type is not None:
            assert self.add_prompt_layer != 0
            self.prompt_embedding = None
            return

        word_weight = self.roberta.embeddings.word_embeddings.weight
        layer_list = self.config.prompt_layer_list

        if layer_list is None:
            num_rows = self.num_prompt_tokens
        elif self.initialize_from_vocab:
            raise NotImplementedError()
        else:
            num_rows = self.num_prompt_tokens * len(layer_list)

        # Created (and randomly initialised) on the default device first, then
        # moved to match the word embeddings.
        prompt_embedding = nn.Embedding(num_rows, self.config.hidden_size)
        prompt_embedding = prompt_embedding.to(device=word_weight.device, dtype=word_weight.dtype)

        if layer_list is None and self.initialize_from_vocab:
            indices = np.random.permutation(range(self._PROMPT_INIT_VOCAB_RANGE))[:self.num_prompt_tokens]
            with torch.no_grad():
                prompt_embedding.weight.copy_(word_weight.detach()[indices])

        self.prompt_embedding = prompt_embedding

    def get_generator_param(self):
        """Return the state dict of the encoder's ``prompt_generator``."""
        return self.roberta.encoder.prompt_generator.state_dict()

    def load_generator_param(self, param):
        """Load ``param`` (a state dict) into the encoder's ``prompt_generator``."""
        self.roberta.encoder.prompt_generator.load_state_dict(param)

    def get_copy_of_trainable_weights(self):
        """Return ``{name: detached copy}`` for every parameter with ``requires_grad``."""
        return {
            name: param.detach().clone()
            for name, param in self.named_parameters()
            if param.requires_grad
        }

    def update_trainable_weights_from_dict(self, weights):
        """Overwrite every trainable parameter with ``weights[name]``.

        Values are copied in place, so each parameter keeps its own device and
        dtype. A missing name raises ``KeyError``.
        """
        with torch.no_grad():
            for name, param in self.named_parameters():
                if param.requires_grad:
                    param.copy_(weights[name])

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _lookup_prompts(self, block_idx, batch_size, device):
        """Embed the ``block_idx``-th block of prompt ids, repeated for every batch row."""
        start = block_idx * self.num_prompt_tokens
        prompt_ids = torch.arange(start, start + self.num_prompt_tokens, device=device)
        prompt_ids = prompt_ids.view(1, -1).repeat(batch_size, 1)
        return self.prompt_embedding(prompt_ids)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        mask_pos=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor`, `optional`):
            Target token ids at the masked positions, compared against the logits produced at ``mask_pos``.
            Indices set to ``-100`` are ignored by the loss.
        mask_pos (:obj:`torch.LongTensor`):
            Position of the mask token in each input sequence, counted before prompts are prepended. Presumably of
            shape ``(batch_size,)`` — it is paired element-wise with ``torch.arange(batch_size)``. Required.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        prompt_embedding = None
        if self.config.mode == 'vanilla_pt':
            reference = input_ids if input_ids is not None else inputs_embeds
            batch_size, device = reference.size(0), reference.device
            layer_list = self.config.prompt_layer_list
            if layer_list is None:
                # Vanilla prompt tuning (no generator) on a single layer: one
                # block of prompts.
                prompt_embedding = self._lookup_prompts(0, batch_size, device)
            else:
                # Vanilla prompt tuning on several layers: block `idx` holds
                # the prompts for the idx-th entry of prompt_layer_list.
                prompt_embedding = [
                    self._lookup_prompts(idx, batch_size, device) for idx in range(len(layer_list))
                ]

        # Prompts are prepended to the sequence, shifting every original
        # position right by num_prompt_tokens.
        mask_pos = mask_pos + self.num_prompt_tokens

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            prompt_embedding=prompt_embedding,
        )

        sequence_output = outputs[0]

        # Gather the hidden state at each row's mask position.
        row_idx = torch.arange(sequence_output.size(0), device=sequence_output.device)
        sequence_mask_output = sequence_output[row_idx, mask_pos]
        prediction_scores = self.lm_head(sequence_mask_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[1] is the (always None) pooled output; skip it.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

