# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """


import logging
import warnings

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu, BertPooler, BertLayer


import torch.nn.functional as F


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Config / tokenizer class names kept as plain strings. They are not referenced in the
# visible part of this file -- presumably intended for docstring helpers (TODO confirm).
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

# RoBERTa checkpoint identifiers. How this list is consumed is not visible in this part
# of the file.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]


class RobertaEmbeddings(BertEmbeddings):
    """
    ``BertEmbeddings`` variant whose position ids start right after the padding index.

    The padding index is fixed to ``1`` here (the ``config.pad_token_id`` lookup is
    commented out in favour of the literal), and both the word and the position
    embedding tables are rebuilt with that ``padding_idx``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = 1 #config.pad_token_id
        hidden_size = config.hidden_size
        # Replace the tables created by the parent class so that they carry ``padding_idx``.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=self.padding_idx)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, hidden_size, padding_idx=self.padding_idx
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Fill in ``position_ids`` when the caller did not supply them, then defer to the
        parent ``forward``.

        With ``input_ids`` available the ids come from ``create_position_ids_from_input_ids``
        (a helper defined elsewhere in this module; per the original comment, padded tokens
        remain padded). Otherwise they are generated sequentially from ``inputs_embeds``.
        """
        if position_ids is None and input_ids is not None:
            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
        elif position_ids is None:
            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        return super().forward(
            input_ids, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds
        )

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """ Build sequential position ids for embeddings that were passed in directly.

        Padding cannot be inferred from embeddings, so every position gets an id; the
        ids run from ``padding_idx + 1`` upwards along dimension 1.

        :param torch.Tensor inputs_embeds:
        :return torch.Tensor: ids expanded to the shape of ``inputs_embeds`` minus its last dimension
        """
        leading_shape = inputs_embeds.size()[:-1]
        first_position = self.padding_idx + 1
        last_position = first_position + leading_shape[1]

        sequential_ids = torch.arange(
            first_position, last_position, dtype=torch.long, device=inputs_embeds.device
        )
        return sequential_ids.unsqueeze(0).expand(leading_shape)


# Docstring fragments; both are currently empty raw strings. No use of them is visible
# in this part of the file.
ROBERTA_START_DOCSTRING = r""""""

ROBERTA_INPUTS_DOCSTRING = r""""""





class DeskepelerEncoder(nn.Module):
    """
    Stack of ``BertLayer`` blocks that can append extra "ke" vectors to the sequence.

    When ``ke_embed`` is passed to :meth:`forward`, it is projected by ``ke_head``
    (linear + tanh) and concatenated to the hidden states along the sequence dimension
    just before the layer with index ``ke_layer_index``; the attention mask is extended
    with ``ke_valid`` at the same time. All following layers, and the returned hidden
    states, therefore cover ``seq_len + extraK`` positions.
    """

    # Index of the layer in front of which the projected ``ke_embed`` is appended.
    # If the stack has no layer with this index, ``ke_embed`` is never used.
    ke_layer_index = 6

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.ke_head = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                        nn.Tanh())

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        ke_embed=None,
        ke_valid=None
    ):
        """
        Run the layer stack, optionally splicing ``ke_embed`` in at ``ke_layer_index``.

        Args:
            hidden_states: input to the first layer.
            attention_mask: mask handed to every layer; extended along its last
                dimension with ``ke_valid`` once ``ke_embed`` has been appended.
            head_mask: indexable with one entry per layer, or ``None`` to pass ``None``
                to every layer.
            encoder_hidden_states, encoder_attention_mask: forwarded unchanged to each layer.
            output_attentions: if true, collect ``layer_outputs[1]`` of every layer.
            output_hidden_states: if true, collect the input of every layer plus the
                final output.
            ke_embed: optional extra vectors, (bsz, extraK, dim).
            ke_valid: mask entries for the extra vectors; required together with
                ``attention_mask`` whenever ``ke_embed`` is actually appended.

        Returns:
            Tuple ``(hidden_states, [all_hidden_states], [all_attentions])`` where the
            bracketed items are present only when requested.

        Raises:
            ValueError: if ``ke_embed`` is about to be appended but ``ke_valid`` or
                ``attention_mask`` is missing.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):

            if i == self.ke_layer_index and ke_embed is not None:
                if ke_valid is None or attention_mask is None:
                    raise ValueError("`ke_valid` and `attention_mask` are required when `ke_embed` is given")
                ke_encode = self.ke_head(ke_embed)  # (bsz, extraK, dim)
                hidden_states = torch.cat([hidden_states, ke_encode], dim=1)
                # NOTE(review): ke_valid is concatenated as-is onto the mask that is fed to
                # the layers. DeskepelerModel passes an additive mask (0 / -10000) here, so
                # ke_valid presumably has to arrive in that same additive form -- confirm
                # against the caller.
                ke_valid = ke_valid.unsqueeze(1).unsqueeze(1).to(attention_mask)
                attention_mask = torch.cat([attention_mask, ke_valid], dim=-1)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # ``head_mask`` defaults to None, so it must not be indexed unconditionally.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(
                hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add the output of the last layer.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)





class DeskepelerModel(BertPreTrainedModel):
    """
    Encoder model assembled from :class:`RobertaEmbeddings`, a :class:`DeskepelerEncoder`
    and a ``BertPooler``.

    It derives from ``BertPreTrainedModel`` and is registered with ``RobertaConfig`` and the
    ``"roberta"`` base-model prefix. In addition to the usual inputs, :meth:`forward`
    accepts ``ke_embed`` / ``ke_valid`` and hands them unchanged to the encoder.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = DeskepelerEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer_idx in heads_to_prune:
            self.encoder.layer[layer_idx].attention.prune_heads(heads_to_prune[layer_idx])

    def get_extended_attention_mask(self, attention_mask, input_shape, device):
        """
        Turn a 2D or 3D attention mask into an additive mask that broadcasts over heads.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` holding 0.0 where attention is allowed and -10000.0 where it is
            masked, cast to the dtype of the model's first parameter.

        Raises:
            ValueError: if ``attention_mask`` is neither 2D nor 3D.
        """
        rank = attention_mask.dim()
        if rank not in (2, 3):
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        if rank == 3:
            # Caller supplied [batch_size, from_seq_length, to_seq_length]: only add a head axis.
            broadcast_mask = attention_mask[:, None, :, :]
        elif self.config.is_decoder:
            # Padding mask [batch_size, seq_length] combined with a causal mask.
            batch_size, seq_length = input_shape
            positions = torch.arange(seq_length, device=device)
            causal = positions[None, None, :].repeat(batch_size, seq_length, 1) <= positions[None, :, None]
            # causal and attention masks must have same type with pytorch version < 1.3
            causal = causal.to(attention_mask.dtype)
            broadcast_mask = causal[:, None, :, :] * attention_mask[:, None, None, :]
        else:
            # Plain padding mask: add head and query axes.
            broadcast_mask = attention_mask[:, None, None, :]

        # The mask is added to the raw attention scores before the softmax, so a large
        # negative value at masked positions effectively removes them.
        param_dtype = next(self.parameters()).dtype  # fp16 compatibility
        broadcast_mask = broadcast_mask.to(dtype=param_dtype)
        return (1.0 - broadcast_mask) * -10000.0

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        # return_tuple=None,
        ke_embed=None,
        ke_valid=None
    ):
        """
        Embed the inputs, run the encoder and pool its output.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. A missing
        ``attention_mask`` defaults to all ones and missing ``token_type_ids`` to all zeros.
        ``output_attentions`` / ``output_hidden_states`` fall back to the config values when
        ``None``. ``ke_embed`` and ``ke_valid`` are passed through to the encoder untouched.

        Returns:
            ``(sequence_output, pooled_output)`` followed by whatever extra items the encoder
            returned (hidden states and/or attentions, when requested).

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        # return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Additive self-attention mask, broadcastable over all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # A cross-attention mask is only prepared in decoder mode with encoder states present.
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            enc_batch_size, enc_seq_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((enc_batch_size, enc_seq_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and is reshaped so that it can be indexed per layer and broadcast over
        # batch and both sequence axes.
        if head_mask is None:
            head_mask = [None] * self.config.num_hidden_layers
        else:
            mask_rank = head_mask.dim()
            if mask_rank == 1:
                head_mask = head_mask[None, None, :, None, None]
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif mask_rank == 2:
                # One row of head flags per layer.
                head_mask = head_mask[:, None, :, None, None]
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            ke_embed=ke_embed,
            ke_valid=ke_valid
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        return (sequence_output, pooled_output) + encoder_outputs[1:]





# class DeskepelerMergeEncoder(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
#         self.des_head = DesPooler(config)

#         self.ke_head = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
#                                         nn.Tanh())

#     def forward(
#         self,
#         hidden_states,
#         attention_mask=None,
#         head_mask=None,
#         encoder_hidden_states=None,
#         encoder_attention_mask=None,
#         output_attentions=False,
#         output_hidden_states=False,
#         bsz=None,
#         seq_len=None,
#         ke_valid=None,
#     ):
#         all_hidden_states =  None
#         all_attentions = None
#         split_layer = 6

#         for i, layer_module in enumerate(self.layer):

#             if i==split_layer and bsz is not None:
#                 inputs_part = hidden_states[:bsz, :seq_len, :]
#                 inputs_mask = attention_mask[:bsz:, :, :, :seq_len]
#                 hidden_states = hidden_states[bsz: ]
#                 attention_mask = attention_mask[bsz: ]

#             if output_hidden_states:
#                 all_hidden_states = all_hidden_states + (hidden_states,)

#             layer_outputs = layer_module(
#                 hidden_states, attention_mask, None, None, None, # head_mask[i], encoder_hidden_states, encoder_attention_mask
#             )
            
#             hidden_states = layer_outputs[0]


#         if bsz is not None:
#             ke_output = hidden_states
#             ke_embed = self.des_head(ke_output)

#             ke_encode = self.ke_head(ke_embed[:ke_valid.shape[0]*ke_valid.shape[1]])
#             ke_encode = ke_encode.view(ke_valid.shape[0], ke_valid.shape[1], -1)
#             seq_len = inputs_part.shape[1]
#             inputs_part = torch.cat([inputs_part, ke_encode], dim=1)
#             ke_valid = ke_valid.unsqueeze(1).unsqueeze(1).to(inputs_mask)
#             inputs_mask = torch.cat([inputs_mask, ke_valid], dim=-1)

#             for i in range(split_layer, len(self.layer)):
#                 layer_module = self.layer[i]
#                 layer_outputs = layer_module(
#                     inputs_part, inputs_mask, None, None, None,
#                 )
                
#                 inputs_part = layer_outputs[0]
#             return (inputs_part[:,:seq_len], ke_embed)
#         else:
 
#             assert(False)

#             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)


# class DeskepelerMergeModel(BertPreTrainedModel):
#     """
#     This class overrides :class:`~transformers.BertModel`. Please check the
#     superclass for the appropriate documentation alongside usage examples.
#     """

#     config_class = RobertaConfig
#     base_model_prefix = "roberta"

#     def __init__(self, config):
#         super().__init__(config)

#         self.embeddings = RobertaEmbeddings(config)

#         self.encoder = DeskepelerMergeEncoder(config)
#         self.pooler = BertPooler(config)

#         self.init_weights()


#     def get_input_embeddings(self):
#         return self.embeddings.word_embeddings

#     def set_input_embeddings(self, value):
#         self.embeddings.word_embeddings = value

#     def _prune_heads(self, heads_to_prune):
#         """ Prunes heads of the model.
#             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
#             See base class PreTrainedModel
#         """
#         for layer, heads in heads_to_prune.items():
#             self.encoder.layer[layer].attention.prune_heads(heads)
            
#     def get_extended_attention_mask(self, attention_mask, input_shape, device):
#         """
#         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

#         Arguments:
#             attention_mask (:obj:`torch.Tensor`):
#                 Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
#             input_shape (:obj:`Tuple[int]`):
#                 The shape of the input to the model.
#             device: (:obj:`torch.device`):
#                 The device of the input to the model.

#         Returns:
#             :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
#         """
#         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
#         # ourselves in which case we just need to make it broadcastable to all heads.
#         if attention_mask.dim() == 3:
#             extended_attention_mask = attention_mask[:, None, :, :]
#         elif attention_mask.dim() == 2:
#             # Provided a padding mask of dimensions [batch_size, seq_length]
#             # - if the model is a decoder, apply a causal mask in addition to the padding mask
#             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
#             if self.config.is_decoder:
#                 batch_size, seq_length = input_shape
#                 seq_ids = torch.arange(seq_length, device=device)
#                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
#                 # causal and attention masks must have same type with pytorch version < 1.3
#                 causal_mask = causal_mask.to(attention_mask.dtype)
#                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
#             else:
#                 extended_attention_mask = attention_mask[:, None, None, :]
#         else:
#             raise ValueError(
#                 "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
#                     input_shape, attention_mask.shape
#                 )
#             )

#         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
#         # masked positions, this operation will create a tensor which is 0.0 for
#         # positions we want to attend and -10000.0 for masked positions.
#         # Since we are adding it to the raw scores before the softmax, this is
#         # effectively the same as removing these entirely.
#         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
#         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
#         return extended_attention_mask

#     def forward(
#         self,
#         input_ids=None,
#         attention_mask=None,
#         token_type_ids=None,
#         position_ids=None,
#         head_mask=None,
#         inputs_embeds=None,
#         encoder_hidden_states=None,
#         encoder_attention_mask=None,
#         output_attentions=None,
#         output_hidden_states=None,
#         # return_tuple=None,
#         ke_input_ids=None,
#         ke_attention_mask=None,
#         ke_valid=None,
#     ):
#         device = input_ids.device if input_ids is not None else inputs_embeds.device

#         if ke_input_ids is not None:
#             bsz, seq_len = input_ids.size()
#             one_pad = torch.zeros((bsz, ke_input_ids.shape[1]-seq_len), device=device)
#             zero_pad = torch.zeros((bsz, ke_input_ids.shape[1]-seq_len), device=device)

#             input_ids = torch.cat([input_ids, one_pad.to(input_ids)], dim=1)
#             attention_mask = torch.cat([attention_mask, zero_pad.to(attention_mask)], dim=1)

#             input_ids = torch.cat([input_ids, ke_input_ids], dim=0)
#             attention_mask = torch.cat([attention_mask, ke_attention_mask], dim=0)


#         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
#         output_hidden_states = (
#             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
#         )
#         # return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

#         if input_ids is not None and inputs_embeds is not None:
#             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
#         elif input_ids is not None:
#             input_shape = input_ids.size()
#         elif inputs_embeds is not None:
#             input_shape = inputs_embeds.size()[:-1]
#         else:
#             raise ValueError("You have to specify either input_ids or inputs_embeds")


#         if attention_mask is None:
#             attention_mask = torch.ones(input_shape, device=device)
#         if token_type_ids is None:
#             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

#         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
#         # ourselves in which case we just need to make it broadcastable to all heads.
#         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)


#         # If a 2D ou 3D attention mask is provided for the cross-attention
#         # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
#         if self.config.is_decoder and encoder_hidden_states is not None:
#             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
#             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
#             if encoder_attention_mask is None:
#                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
#             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
#         else:
#             encoder_extended_attention_mask = None

#         # Prepare head mask if needed
#         # 1.0 in head_mask indicate we keep the head
#         # attention_probs has shape bsz x n_heads x N x N
#         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
#         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
#         # head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
#         if head_mask is not None:
#             if head_mask.dim() == 1:
#                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
#                 head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
#             elif head_mask.dim() == 2:
#                 head_mask = (
#                     head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
#                 )  # We can specify head_mask for each layer
#             head_mask = head_mask.to(
#                 dtype=next(self.parameters()).dtype
#             )  # switch to fload if need + fp16 compatibility
#         else:
#             head_mask = [None] * self.config.num_hidden_layers


#         embedding_output = self.embeddings(
#             input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
#         )
#         encoder_outputs = self.encoder(
#             embedding_output,
#             attention_mask=extended_attention_mask,
#             head_mask=head_mask,
#             encoder_hidden_states=encoder_hidden_states,
#             encoder_attention_mask=encoder_extended_attention_mask,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             bsz=bsz,
#             seq_len=seq_len,
#             ke_valid=ke_valid
#         )
#         return encoder_outputs 
#         # sequence_output = encoder_outputs[0]
#         # pooled_output = self.pooler(sequence_output)


#         # return (sequence_output, pooled_output) + encoder_outputs[1:]




class RobertaLMHead(nn.Module):
    """Masked-language-modeling head: dense -> gelu -> layer norm -> projection to the vocabulary."""

    def __init__(self, config):
        super().__init__()
        hidden_size, vocab_size = config.hidden_size, config.vocab_size

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)

        # The projection is created without a bias; a separate parameter is attached below.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Return vocabulary-sized scores for ``features``; extra keyword arguments are ignored."""
        transformed = self.layer_norm(gelu(self.dense(features)))

        # project back to size of vocabulary with bias
        return self.decoder(transformed)



# class DesPooler(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
#         self.activation = nn.Tanh()

#     def forward(self, hidden_states):
#         # We "pool" the model by simply taking the hidden state corresponding
#         # to the first token.
#         first_token_tensor = hidden_states[:, 0]
#         pooled_output = self.dense(first_token_tensor)
#         pooled_output = self.activation(pooled_output)
#         return pooled_output



class DeskepelerForPretraining(BertPreTrainedModel):
    """Pre-training wrapper combining a masked-LM loss with a knowledge-embedding (KE) loss.

    The same ``DeskepelerModel`` encoder is run twice per step: once over all
    description sequences (extra entities, head, tail, co-occurrence text) and
    once over the main input, which additionally receives the extra-entity
    embeddings produced by the first pass.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)

        self.roberta = DeskepelerModel(config)
        self.lm_head = RobertaLMHead(config)

        # Norm order and margin used when scoring (h, r, t) triples.
        self.p = 2
        self.margin = 2
        self.init_weights()

    def get_output_embeddings(self):
        """Return the LM head's vocabulary projection layer."""
        return self.lm_head.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        ke_ids=None,
        ke_mask=None,
        ke_valid=None,
        h_ids=None,
        h_mask=None,
        t_ids=None,
        t_mask=None,
        cooccur_ids=None,
        cooccur_mask=None,
        ke_entity_ids=None,
        h_entity_ids=None,
        t_entity_ids=None,
    ):
        """Compute the joint pre-training loss.

        Args:
            input_ids, attention_mask: main text batch fed to the encoder.
            labels: MLM targets; positions equal to ``-1`` are ignored.
            ke_ids, ke_mask: 3-D tensors, unpacked as ``(bsz, extraK, seq_len)``
                and flattened to ``(bsz * extraK, seq_len)`` before encoding.
            ke_valid: forwarded unchanged to the encoder together with the
                extra-entity embeddings.
            h_ids / h_mask, t_ids / t_mask, cooccur_ids / cooccur_mask:
                concatenated along dim 0 after the flattened ``ke_*`` tensors;
                the slicing below treats the head and tail groups as ``bsz``
                rows each and everything after them as the relation text.
            ke_entity_ids, h_entity_ids, t_entity_ids: integer entity ids used
                only to decide which in-batch negatives are usable.
            token_type_ids, position_ids, head_mask, inputs_embeds,
            output_attentions, output_hidden_states: accepted but not used
                by this method.

        Returns:
            Tuple ``(loss, masked_lm_loss, ke_loss)`` with
            ``loss = masked_lm_loss + ke_loss``.
        """
        batch_size, num_extra, desc_len = ke_ids.shape
        num_ke = batch_size * num_extra

        # ---- Encode every description sequence in a single encoder pass ----
        desc_ids = torch.cat(
            [ke_ids.view(num_ke, desc_len), h_ids, t_ids, cooccur_ids], dim=0
        )
        desc_mask = torch.cat(
            [ke_mask.view(num_ke, desc_len), h_mask, t_mask, cooccur_mask], dim=0
        )
        # Element 1 of the encoder output is used as the per-sequence embedding
        # (presumably the pooled output -- confirm against DeskepelerModel).
        desc_embed = self.roberta(desc_ids, attention_mask=desc_mask)[1]

        head_end = num_ke + batch_size
        tail_end = num_ke + batch_size * 2
        ke_embed = desc_embed[:num_ke]
        h = desc_embed[num_ke:head_end]
        t = desc_embed[head_end:tail_end]
        r = desc_embed[tail_end:]

        # ---- Masked-LM loss on the main input ----
        text_len = input_ids.shape[1]
        text_outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            ke_embed=ke_embed.view(batch_size, num_extra, -1),
            ke_valid=ke_valid,
        )
        # Keep only the first `text_len` positions of the encoder output
        # before predicting tokens.
        token_states = text_outputs[0][:, :text_len, :]
        prediction_scores = self.lm_head(token_states)

        mlm_criterion = CrossEntropyLoss(ignore_index=-1)
        masked_lm_loss = mlm_criterion(
            prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
        )

        # ---- KE loss: positive triples ----
        pos_scores = self.margin - torch.norm(h + r - t, p=self.p, dim=-1)
        pos_loss = F.logsigmoid(pos_scores)

        # ---- KE loss: in-batch negatives ----
        # Every encoded entity in the batch (extra entities, heads, tails) is a
        # candidate replacement for the head and for the tail of each triple.
        candidate_ids = torch.cat(
            [ke_entity_ids.view(-1), h_entity_ids.view(-1), t_entity_ids.view(-1)]
        ).unsqueeze(0)
        candidates = torch.cat([ke_embed, h, t])

        # A candidate is usable for a triple only if it is neither that triple's
        # own head nor its own tail, and its id is non-negative.
        differs_from_head = h_entity_ids.unsqueeze(1) != candidate_ids
        differs_from_tail = t_entity_ids.unsqueeze(1) != candidate_ids
        usable = differs_from_head & differs_from_tail & (candidate_ids >= 0)
        # Duplicate the mask: the first half covers head replacements, the
        # second half tail replacements (same order as the concatenation below).
        usable = torch.cat([usable, usable], dim=-1).to(pos_loss)

        corrupted_head = candidates + (r - t).unsqueeze(1)
        corrupted_tail = (h + r).unsqueeze(1) - candidates
        neg_rep = torch.cat([corrupted_head, corrupted_tail], dim=1)
        neg_scores = self.margin - torch.norm(neg_rep, p=self.p, dim=-1)
        neg_loss = F.logsigmoid(-neg_scores)
        # Average over usable negatives only; the epsilon guards against a row
        # with no usable negatives.
        neg_loss = torch.sum(neg_loss * usable, dim=-1) / (torch.sum(usable, dim=-1) + 1e-6)

        ke_loss = (-pos_loss.mean() - neg_loss.mean()) / 2.0

        loss = masked_lm_loss + ke_loss

        return (loss, masked_lm_loss, ke_loss)


        # # MLM part
        # bsz, extraK, seq_len = ke_ids.shape
        # ke_ids = ke_ids.view(bsz*extraK, seq_len)
        # ke_mask = ke_mask.view(bsz*extraK, seq_len)

        # ke_ids_plus = torch.cat([ke_ids, h_ids, t_ids, cooccur_ids], dim=0)
        # ke_mask_plus = torch.cat([ke_mask, h_mask, t_mask, cooccur_mask], dim=0)

        # ke_outputs = self.roberta(
        #     input_ids=input_ids,
        #     attention_mask=attention_mask,
        #     ke_input_ids=ke_ids_plus,
        #     ke_attention_mask=ke_mask_plus,
        #     ke_valid=ke_valid,
        # )
    
        # sequence_output = ke_outputs[0]
        # ke_embed_plus = ke_outputs[1]

        # ke_embed = ke_embed_plus[:bsz*extraK]
        # h = ke_embed_plus[bsz*extraK: bsz*extraK+bsz]
        # t = ke_embed_plus[bsz*extraK+bsz: bsz*extraK+bsz*2]
        # r = ke_embed_plus[bsz*extraK+bsz*2: ]
        

        # sequence_output = sequence_output[:, :seq_len, :]
        # prediction_scores = self.lm_head(sequence_output)

        # loss_fct = CrossEntropyLoss(ignore_index=-1)
        # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # pRep = h + r - t

        # pScores = self.margin - torch.norm(pRep, p=self.p, dim=-1)

        # pLoss = F.logsigmoid(pScores)

        # # (bsz, K, dim)
        # ke_entity_ids = torch.cat([ke_entity_ids.view(-1), h_entity_ids.view(-1), t_entity_ids.view(-1)])
        # neg_entity = torch.cat([ke_embed, h, t])

        # _ke_entity_ids = ke_entity_ids.unsqueeze(0)
        # neg_valid = ( h_entity_ids.unsqueeze(1)!=_ke_entity_ids ) & ( t_entity_ids.unsqueeze(1)!=_ke_entity_ids) & (_ke_entity_ids >=0 )
        # neg_valid = torch.cat([neg_valid, neg_valid], dim=-1)
        # neg_valid = neg_valid.to(pLoss)

        # replace_h_score = neg_entity + (r - t).unsqueeze(1)
        # replace_t_score = (h + r).unsqueeze(1) - neg_entity
        # nRep = torch.cat([replace_h_score, replace_t_score], dim=1)
        # nScores = self.margin - torch.norm(nRep, p=self.p, dim=-1)
        # nLoss = F.logsigmoid(-nScores) 
        # nLoss = torch.sum(nLoss * neg_valid, dim=-1) / (torch.sum(neg_valid, dim=-1)+1e-6)

        # ke_loss = (-pLoss.mean()-nLoss.mean()) /2.0

        # loss = masked_lm_loss + ke_loss

        # return (loss, masked_lm_loss, ke_loss)


        # ke_embed = self.des_head(ke_outputs[0])
        # ke_encode = self.ke_head(ke_embed)
        # ke_encode = ke_encode.view(bsz, extraK, -1)



        # return (loss, masked_lm_loss, ke_loss)


        # bsz, extraK, seq_len = ke_ids.shape
        # ke_ids = ke_ids.view(bsz*extraK, seq_len)
        # ke_mask = ke_mask.view(bsz*extraK, seq_len)

        # ke_ids_plus = torch.cat([ke_ids, h_ids, t_ids, cooccur_ids], dim=0)
        # ke_mask_plus = torch.cat([ke_mask, h_mask, t_mask, cooccur_mask], dim=0)

        
        # ke_outputs_plus = self.roberta(
        #     ke_ids_plus,
        #     attention_mask=ke_mask_plus,
        # )
        # ke_embed_plus = self.des_head(ke_outputs_plus[0])
        
        # ke_embed = ke_embed_plus[:bsz*extraK]
        # h = ke_embed_plus[bsz*extraK: bsz*extraK+bsz]
        # t = ke_embed_plus[bsz*extraK+bsz: bsz*extraK+bsz*2]
        # r = ke_embed_plus[bsz*extraK+bsz*2: ]
        
        # ke_encode = self.ke_head(ke_embed)  # (bsz*extraK, dim)
        # ke_encode = ke_encode.view(bsz, extraK, -1)
        
        # seq_len = input_ids.shape[1]
        # outputs = self.roberta(
        #     input_ids,
        #     attention_mask=attention_mask,
        #     ke_encode=ke_encode,
        #     ke_valid=ke_valid
        # )
        # sequence_output = outputs[0]
        # sequence_output = sequence_output[:, :seq_len, :]
        # prediction_scores = self.lm_head(sequence_output)

        # loss_fct = CrossEntropyLoss(ignore_index=-1)
        # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # pRep = h + r - t

        # pScores = self.margin - torch.norm(pRep, p=self.p, dim=-1)

        # pLoss = F.logsigmoid(pScores)

        # # (bsz, K, dim)
        # ke_entity_ids = torch.cat([ke_entity_ids.view(-1), h_entity_ids.view(-1), t_entity_ids.view(-1)])
        # neg_entity = torch.cat([ke_embed, h, t])

        # _ke_entity_ids = ke_entity_ids.unsqueeze(0)
        # neg_valid = ( h_entity_ids.unsqueeze(1)!=_ke_entity_ids ) & ( t_entity_ids.unsqueeze(1)!=_ke_entity_ids) & (_ke_entity_ids >=0 )
        # neg_valid = torch.cat([neg_valid, neg_valid], dim=-1)
        # neg_valid = neg_valid.to(pLoss)

        # replace_h_score = neg_entity + (r - t).unsqueeze(1)
        # replace_t_score = (h + r).unsqueeze(1) - neg_entity
        # nRep = torch.cat([replace_h_score, replace_t_score], dim=1)
        # nScores = self.margin - torch.norm(nRep, p=self.p, dim=-1)
        # nLoss = F.logsigmoid(-nScores) 
        # nLoss = torch.sum(nLoss * neg_valid, dim=-1) / (torch.sum(neg_valid, dim=-1)+1e-6)

        # ke_loss = (-pLoss.mean()-nLoss.mean()) /2.0

        # loss = masked_lm_loss + ke_loss

        # return (loss, masked_lm_loss, ke_loss)


# class DeskepelerForPretraining(BertPreTrainedModel):
#     config_class = RobertaConfig
#     base_model_prefix = "roberta"

#     def __init__(self, config):
#         super().__init__(config)

#         self.roberta = DeskepelerModel(config)
#         self.lm_head = RobertaLMHead(config)
#         self.des_head = DesPooler(config)

#         self.ke_head = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
#                                         nn.Tanh())

#         self.p = 2
#         self.margin = 2
#         self.init_weights()

#     def get_output_embeddings(self):
#         return self.lm_head.decoder


#     def forward(
#         self,
#         input_ids=None,
#         attention_mask=None,
#         token_type_ids=None,
#         position_ids=None,
#         head_mask=None,
#         inputs_embeds=None,
#         labels=None,
#         output_attentions=None,
#         output_hidden_states=None,
#         # return_tuple=None,
#         ke_ids=None,
#         ke_mask=None,
#         ke_valid=None,
#         # **kwargs
#         h_ids=None,
#         h_mask=None,
#         t_ids=None,
#         t_mask=None,
#         # neg_entity=None,
#         # h_neg_valid=None,
#         # t_neg_valid=None,
#         cooccur_ids=None,
#         cooccur_mask=None,
#         ke_entity_ids=None, 
#         # ke_embed=None,
#         h_entity_ids=None,
#         t_entity_ids=None,
#     ):
                
#         # MLM part
#         bsz, extraK, seq_len = ke_ids.shape
#         ke_ids = ke_ids.view(bsz*extraK, seq_len)
#         ke_mask = ke_mask.view(bsz*extraK, seq_len)

#         ke_outputs = self.roberta(
#             ke_ids,
#             attention_mask=ke_mask,
#         )
#         ke_embed = self.des_head(ke_outputs[0])
#         ke_encode = self.ke_head(ke_embed)
#         ke_encode = ke_encode.view(bsz, extraK, -1)

#         seq_len = input_ids.shape[1]
#         outputs = self.roberta(
#             input_ids,
#             attention_mask=attention_mask,
#             ke_encode=ke_encode,
#             ke_valid=ke_valid
#         )
#         sequence_output = outputs[0]
#         sequence_output = sequence_output[:, :seq_len, :]
#         prediction_scores = self.lm_head(sequence_output)


#         loss_fct = CrossEntropyLoss(ignore_index=-1)
#         masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

#         # KE part

#         h_outputs = self.roberta(
#             h_ids,
#             attention_mask=h_mask
#         )
#         h = self.des_head(h_outputs[0])  # (bsz, dim)

#         t_outputs = self.roberta(
#             t_ids,
#             attention_mask=t_mask
#         )
#         t = self.des_head(t_outputs[0])


#         des_outputs = self.roberta(
#             cooccur_ids,
#             attention_mask=cooccur_mask
#         )
#         r = self.des_head(des_outputs[0])

#         pRep = h + r - t

#         pScores = self.margin - torch.norm(pRep, p=self.p, dim=-1)

#         pLoss = F.logsigmoid(pScores)


#         # (bsz, K, dim)
#         ke_entity_ids = torch.cat([ke_entity_ids.view(-1), h_entity_ids.view(-1), t_entity_ids.view(-1)])
#         neg_entity = torch.cat([ke_embed, h, t])

#         _ke_entity_ids = ke_entity_ids.unsqueeze(0)
#         neg_valid = ( h_entity_ids.unsqueeze(1)!=_ke_entity_ids ) & ( t_entity_ids.unsqueeze(1)!=_ke_entity_ids) & (_ke_entity_ids >=0 )
#         neg_valid = torch.cat([neg_valid, neg_valid], dim=-1)
#         neg_valid = neg_valid.to(pLoss)

#         replace_h_score = neg_entity + (r - t).unsqueeze(1)
#         replace_t_score = (h + r).unsqueeze(1) - neg_entity
#         nRep = torch.cat([replace_h_score, replace_t_score], dim=1)
#         nScores = self.margin - torch.norm(nRep, p=self.p, dim=-1)
#         nLoss = F.logsigmoid(-nScores) 
#         nLoss = torch.sum(nLoss * neg_valid, dim=-1) / (torch.sum(neg_valid, dim=-1)+1e-6)

#         ke_loss = (-pLoss.mean()-nLoss.mean()) /2.0

#         loss = masked_lm_loss + ke_loss

#         return (loss, masked_lm_loss, ke_loss)




class DeskepelerForKEembed(BertPreTrainedModel):
    """Thin wrapper that encodes sequences and returns the encoder's second output.

    The LM head is instantiated (and exposed through
    ``get_output_embeddings``) but is not used in ``forward``.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)

        self.roberta = DeskepelerModel(config)
        self.lm_head = RobertaLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the LM head's vocabulary projection layer."""
        return self.lm_head.decoder

    def forward(
        self,
        ke_ids=None,
        ke_mask=None
    ):
        """Encode ``ke_ids`` (with ``ke_mask`` as attention mask).

        Returns:
            A 1-tuple holding element 1 of the encoder output (presumably the
            pooled sequence embedding -- confirm against DeskepelerModel).
        """
        encoder_outputs = self.roberta(ke_ids, attention_mask=ke_mask)
        sequence_embedding = encoder_outputs[1]
        return (sequence_embedding, )



class DeskepelerForSequenceClassification(BertPreTrainedModel):
    """Sequence classification / regression on top of ``DeskepelerModel``.

    Entity ids are looked up in an embedding table (``ke_embed``) and the
    resulting vectors are handed to the encoder alongside the token input.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = DeskepelerModel(config)
        self.classifier = RobertaClassificationHead(config)

        # Entity embedding table; index 0 is the padding entry.
        # NOTE(review): 87665 is hard-coded -- presumably the number of rows in
        # the precomputed embedding file; confirm before using another file.
        self.ke_embed = nn.Embedding(87665, config.hidden_size, padding_idx=0)
        self.init_weights()

    def init_ke_embed(self, path='data/fewrel_ke_embed.th'):
        """Load precomputed entity embeddings into ``ke_embed`` and freeze them.

        Args:
            path: file readable by ``torch.load`` that contains the embedding
                matrix. Defaults to the location that was previously hard-coded.
        """
        # Load onto CPU first so a file saved from a GPU run can be read on any
        # machine; `copy_` then moves the values to the parameter's own device.
        word_mat = torch.load(path, map_location="cpu")
        self.ke_embed.weight.data.copy_(word_mat)
        self.ke_embed.weight.requires_grad = False
        print ('loaded KE embed')

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        ke_entity_ids=None,
        ke_valid=None,
    ):
        """Classify the input sequence.

        Args:
            input_ids, attention_mask: token batch fed to the encoder.
            labels: optional targets. With ``num_labels == 1`` an MSE
                (regression) loss is used, otherwise cross-entropy.
            ke_entity_ids: indices into the ``ke_embed`` table.
            ke_valid: forwarded unchanged to the encoder.
            token_type_ids, position_ids, head_mask, inputs_embeds: accepted
                but not used by this method.

        Returns:
            Tuple ``(logits,) + encoder_outputs[2:]``, with the loss prepended
            when ``labels`` is given.
        """
        ke_embed = self.ke_embed(ke_entity_ids)
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            ke_embed=ke_embed,
            ke_valid=ke_valid
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


    
class DeskepelerForQuestionAnswering(BertPreTrainedModel):
    """Extractive question answering (start / end span logits) on top of ``DeskepelerModel``.

    Entity ids are looked up in an embedding table (``ke_embed``) and the
    resulting vectors are handed to the encoder alongside the token input.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = DeskepelerModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Entity embedding table; index 0 is the padding entry.
        # NOTE(review): 87665 is hard-coded -- presumably the number of rows in
        # the precomputed embedding file; confirm before using another file.
        self.ke_embed = nn.Embedding(87665, config.hidden_size, padding_idx=0)

        # Initialize once, after every submodule has been registered (the
        # weights were previously initialized twice, the first pass being
        # overwritten by the second).
        self.init_weights()

    def init_ke_embed(self, path='data/triviaqa_ke_embed.th'):
        """Load precomputed entity embeddings into ``ke_embed`` and freeze them.

        Args:
            path: file readable by ``torch.load`` that contains the embedding
                matrix. Defaults to the location that was previously hard-coded.
        """
        # Load onto CPU first so a file saved from a GPU run can be read on any
        # machine; `copy_` then moves the values to the parameter's own device.
        word_mat = torch.load(path, map_location="cpu")
        self.ke_embed.weight.data.copy_(word_mat)
        self.ke_embed.weight.requires_grad = False
        print ('loaded KE embed')

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        ke_entity_ids=None,
        ke_valid=None,
    ):
        """Predict answer-span start and end logits.

        Args:
            input_ids, attention_mask: token batch fed to the encoder.
            start_positions, end_positions: optional gold span indices; the
                loss is computed only when both are given. The tensors passed
                in are not modified.
            ke_entity_ids: indices into the ``ke_embed`` table.
            ke_valid: forwarded unchanged to the encoder.
            token_type_ids, position_ids, head_mask, inputs_embeds: accepted
                but not used by this method.

        Returns:
            Tuple ``(start_logits, end_logits) + encoder_outputs[2:]``, with
            the averaged start/end loss prepended when positions are given.
        """
        ke_encode = self.ke_embed(ke_entity_ids)
        # NOTE(review): this class passes the entity vectors as `ke_encode=`
        # while the pretraining and sequence-classification classes in this
        # file use `ke_embed=`; confirm which keyword DeskepelerModel.forward
        # accepts.
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            ke_encode=ke_encode,
            ke_valid=ke_valid
        )

        sequence_output = outputs[0]

        # Splitting into size-1 chunks and unpacking two of them requires the
        # last dimension (config.num_labels) to be exactly 2.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            # Out-of-place clamp: the in-place variant would silently rewrite
            # the caller's label tensors (squeeze returns a view).
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
    


class DeskepelerForTriviaQuestionAnswering(BertPreTrainedModel):
    """Question answering with multiple candidate answer spans per example.

    Uses an "or"-style softmax cross-entropy in which the numerator sums over
    every candidate target position instead of a single gold position.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = DeskepelerModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # `forward` looks entity ids up in `self.ke_embed`; without this table
        # the first forward call raised AttributeError.
        # NOTE(review): size and padding index mirror
        # DeskepelerForQuestionAnswering in this file -- confirm they match the
        # embedding file used for this task.
        self.ke_embed = nn.Embedding(87665, config.hidden_size, padding_idx=0)

        self.init_weights()

    def init_ke_embed(self, path='data/triviaqa_ke_embed.th'):
        """Load precomputed entity embeddings into ``ke_embed`` and freeze them.

        Args:
            path: file readable by ``torch.load`` that contains the embedding
                matrix. The default mirrors DeskepelerForQuestionAnswering.
        """
        # Load onto CPU first so a file saved from a GPU run can be read on any
        # machine; `copy_` then moves the values to the parameter's own device.
        word_mat = torch.load(path, map_location="cpu")
        self.ke_embed.weight.data.copy_(word_mat)
        self.ke_embed.weight.requires_grad = False
        print ('loaded KE embed')

    def or_softmax_cross_entropy_loss_one_doc(self, logits, target, ignore_index=-1, dim=-1):
        """loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf

        Args:
            logits: 2-D tensor of scores.
            target: 2-D integer tensor of candidate indices into ``logits``
                along ``dim``; must have the same first dimension as
                ``logits``. Entries equal to ``ignore_index`` are masked out.
            ignore_index: value marking unused target slots.
            dim: dimension used for gathering and for the log-sum-exp.

        Returns:
            The summed loss over entries whose value is not infinite.
        """
        assert logits.ndim == 2
        assert target.ndim == 2
        assert logits.size(0) == target.size(0)

        # with regular CrossEntropyLoss, the numerator is only one of the logits specified by the target
        # here, the numerator is the sum of a few potential targets, where some of them is the correct answer

        # compute a target mask
        target_mask = target == ignore_index
        # replace ignore_index with 0, so `gather` selects the logit at index 0 for the masked targets
        masked_target = target * (1 - target_mask.long())
        # gather logits
        gathered_logits = logits.gather(dim=dim, index=masked_target)
        # Apply the mask to gathered_logits. A large finite negative value is
        # used in place of -inf, so masked entries contribute ~0 after exp().
        # NOTE(review): because the fill value is finite, a call in which every
        # target is `ignore_index` yields a very large finite loss rather than
        # `inf`, so the `isinf` filter at the end does not drop it -- confirm
        # the data pipeline never produces such a batch.
        gathered_logits[target_mask] = -10000.0#float('-inf')

        # each batch is one example
        gathered_logits = gathered_logits.view(1, -1)
        logits = logits.view(1, -1)

        # numerator = log(sum(exp(gathered logits)))
        log_score = torch.logsumexp(gathered_logits, dim=dim, keepdim=False)
        # denominator = log(sum(exp(logits)))
        log_norm = torch.logsumexp(logits, dim=dim, keepdim=False)

        # compute the loss
        loss = -(log_score - log_norm)

        # Drop infinite entries before summing. Use sum instead of mean because
        # it is easier to compute
        return loss[~torch.isinf(loss)].sum()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        answer_masks=None,
        ke_entity_ids=None,
        ke_valid=None,
    ):
        """Predict answer-span start and end logits.

        Args:
            input_ids, attention_mask: token batch fed to the encoder.
            start_positions, end_positions: optional 2-D tensors of candidate
                span indices (``-1`` marks unused slots); the loss is computed
                only when both are given.
            ke_entity_ids: indices into the ``ke_embed`` table.
            ke_valid: forwarded unchanged to the encoder.
            token_type_ids, position_ids, head_mask, inputs_embeds,
            answer_masks: accepted but not used by this method.

        Returns:
            Tuple ``(start_logits, end_logits) + encoder_outputs[2:]``, with
            the averaged start/end loss prepended when positions are given.
        """
        ke_encode = self.ke_embed(ke_entity_ids)
        # NOTE(review): this class passes the entity vectors as `ke_encode=`
        # while the pretraining and sequence-classification classes in this
        # file use `ke_embed=`; confirm which keyword DeskepelerModel.forward
        # accepts.
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            ke_encode=ke_encode,
            ke_valid=ke_valid
        )

        sequence_output = outputs[0]

        # Splitting into size-1 chunks and unpacking two of them requires the
        # last dimension (config.num_labels) to be exactly 2.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            start_loss = self.or_softmax_cross_entropy_loss_one_doc(start_logits, start_positions, ignore_index=-1)
            end_loss = self.or_softmax_cross_entropy_loss_one_doc(end_logits, end_positions, ignore_index=-1)

            total_loss = (start_loss + end_loss) / 2

            outputs = (total_loss,) + outputs

        return outputs

class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head applied to the first token's features."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Return ``config.num_labels`` logits per sequence; extra keyword arguments are ignored."""
        # Position 0 holds the <s> token (equivalent to [CLS]).
        first_token = self.dropout(features[:, 0, :])
        projected = torch.tanh(self.dense(first_token))
        return self.out_proj(self.dropout(projected))



def create_position_ids_from_input_ids(input_ids, padding_idx):
    """ Build position ids that skip padding tokens.

    Non-padding tokens are numbered consecutively along dim 1, starting at
    ``padding_idx + 1``; padding tokens receive ``padding_idx``. Modified from
    fairseq's `utils.make_positions`.

    :param torch.Tensor input_ids: token ids; positions are counted along dim 1.
    :param int padding_idx: id of the padding token.
    :return torch.Tensor: long tensor with the same shape as ``input_ids``.
    """
    # The int -> cumsum -> type_as -> long sequence of casts is kept exactly as
    # is: it is balanced to work with both ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # Multiplying by the mask zeroes the count at padding positions.
    positions = running_count * not_padding
    return positions.long() + padding_idx
