# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import transformers
from transformers import BertPreTrainedModel, BertModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertLayer
from transformers.models.bert.modeling_bert import BertPreTrainingHeads

import logging

logger = logging.getLogger(__name__)

# Emitted once at import time. ``Logger.warn`` is a deprecated alias of
# ``Logger.warning`` and no longer exists on recent Python versions.
logger.warning('Hacking BertSelfAttention! Now it returns attention scores rather than probabilities.')

class BertSelfAttention(transformers.models.bert.modeling_bert.BertSelfAttention):
    """Self-attention variant that reports raw attention scores.

    The computation mirrors the upstream implementation (most of it is copied
    from transformers v4.3.3). The single intentional difference is the
    optional attention output: when ``output_attentions`` is set, the scaled
    and masked pre-softmax scores are returned instead of the probabilities.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(context[, scores][, past_key_value])``.

        ``scores`` is included only when ``output_attentions`` is truthy and
        ``past_key_value`` only when ``self.is_decoder`` is set. ``head_mask``
        is applied to the probabilities, not to the returned scores.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        use_cross_attention = encoder_hidden_states is not None
        has_cache = past_key_value is not None

        if use_cross_attention:
            # Cross-attention always masks with the encoder-side mask.
            attention_mask = encoder_attention_mask

        if use_cross_attention and has_cache:
            # Cached cross-attention keys/values are reused as they are.
            key_layer, value_layer = past_key_value[0], past_key_value[1]
        else:
            kv_input = encoder_hidden_states if use_cross_attention else hidden_states
            key_layer = self.transpose_for_scores(self.key(kv_input))
            value_layer = self.transpose_for_scores(self.value(kv_input))
            if has_cache:
                # Self-attention with a cache: prepend the cached keys/values.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        position_type = self.position_embedding_type
        if position_type in ("relative_key", "relative_key_query"):
            num_positions = hidden_states.size()[1]
            row_ids = torch.arange(num_positions, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            col_ids = torch.arange(num_positions, dtype=torch.long, device=hidden_states.device).view(1, -1)
            offsets = row_ids - col_ids
            rel_embedding = self.distance_embedding(offsets + self.max_position_embeddings - 1)
            rel_embedding = rel_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            query_term = torch.einsum("bhld,lrd->bhlr", query_layer, rel_embedding)
            if position_type == "relative_key":
                attention_scores = attention_scores + query_term
            else:
                key_term = torch.einsum("bhrd,lrd->bhlr", key_layer, rel_embedding)
                attention_scores = attention_scores + query_term + key_term

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward()).
            attention_scores = attention_scores + attention_mask

        # Normalize the scores to probabilities, then apply dropout. Dropout
        # here drops entire tokens to attend to, as in the original
        # Transformer paper.
        attention_probs = self.dropout(F.softmax(attention_scores, dim=-1))

        # Mask heads if requested.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            # Hacked: attention_scores takes the place of attention_probs.
            outputs = outputs + (attention_scores,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


# Replace the class on the upstream module so that later lookups of
# ``BertSelfAttention`` through that module resolve to the variant above.
transformers.models.bert.modeling_bert.BertSelfAttention = BertSelfAttention



class BertForPreTraining(BertPreTrainedModel):
    """BERT wrapper returning per-layer attentions and projected hidden states.

    Each hidden-state tensor produced by the encoder is passed through its own
    linear projection (``fit_denses``) to ``fit_size`` features.
    """

    def __init__(self, config):
        super().__init__(config)
        fit_size = 768  # config.fit_size
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        # One projection per hidden-state tensor. The count is fixed at 7, so
        # forward() indexes out of range if the encoder yields more than seven
        # hidden-state tensors. Kept as-is so parameter names/shapes stay
        # compatible with existing checkpoints.
        self.fit_denses = nn.ModuleList(
            [nn.Linear(config.hidden_size, fit_size) for _ in range(7)]
        )

    def forward(self, input_ids, token_type_ids=None,
                attention_mask=None, masked_lm_labels=None,
                next_sentence_label=None, labels=None,
                output_attentions=True, output_hidden_states=True,):
        """Return ``(attentions, projected_hidden_states)``.

        ``masked_lm_labels``, ``next_sentence_label`` and ``labels`` are
        accepted but not used here. ``projected_hidden_states`` is a list with
        one tensor per encoder hidden state.
        """
        # Request the structured output explicitly: the attribute access below
        # would fail on a plain tuple when ``config.return_dict`` is False.
        outputs = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        att_output = outputs.attentions
        sequence_output = [
            self.fit_denses[layer_idx](layer_states)
            for layer_idx, layer_states in enumerate(outputs.hidden_states)
        ]

        return att_output, sequence_output

    
    
    
class ShallowSkipping(nn.Module):
    """Runs the shallow layers over sliding n-gram windows and merges the result.

    Every non-padding token becomes the centre of a window holding
    ``n_gram_left`` tokens before it and ``n_gram_right`` tokens after it. The
    windows are encoded independently by the shallow encoder layers of the
    parent model and then summed back into one vector per token.

    After padding with ``(left, right)`` and unfolding, the centre token of a
    window sits at column ``self.left``. That column is used throughout,
    which coincides with ``ngram // 2`` for symmetric windows.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped in a tuple so nn.Module does not register the parent model
        # as a submodule of this module.
        self.model = (model,)
        self.config = model.config
        self.shallow_config = model.shallow_config
        self.left = self.config.n_gram_left
        self.right = self.config.n_gram_right
        self.ngram = self.left + 1 + self.right

        # Projection from the shallow width to the full width, only if they differ.
        if self.shallow_config.hidden_size != self.config.hidden_size:
            self.linear = nn.Linear(self.shallow_config.hidden_size, self.config.hidden_size)

    def build_input_ngrams(self, input_ids, token_type_ids):
        """Turn a padded batch of token ids into a flat stack of n-gram windows.

        Token id 0 is treated as padding and dropped; sequence borders are
        marked with -1. ``token_type_ids`` is accepted but not used.

        Returns:
            ``(input_ngram_ids, token_ngram_type_ids, attention_mask,
            original_index)`` where ``input_ngram_ids`` is 2-D with ``ngram``
            columns, ``token_ngram_type_ids`` is always ``None``,
            ``attention_mask`` is a float mask over the window entries and
            ``original_index`` marks the non-padding positions of ``input_ids``.
        """
        left, right, ngram = self.left, self.right, self.ngram

        original_index = (input_ids != 0)

        # Mark the borders of every sequence with -1, drop the 0-padding (this
        # flattens the batch into one stream) and slide a window over it.
        bordered_ids = F.pad(input_ids, (left, right), "constant", -1)
        flat_ids = bordered_ids[bordered_ids != 0]
        input_ngram_ids = flat_ids.unfold(0, ngram, 1)

        token_ngram_type_ids = None

        attention_mask = (input_ngram_ids > 0).float()

        if self.training:
            # Random masking only ever affects the first window column; all
            # other columns are forced to stay visible.
            _mask = torch.rand(attention_mask.shape).to(attention_mask.device)
            _mask = (_mask > self.config.ngram_masking)
            _mask[:, 1:] = 1
            attention_mask *= _mask

        # Keep the centre column visible so a window is never fully masked.
        attention_mask[:, left] = 1
        return input_ngram_ids, token_ngram_type_ids, attention_mask, original_index

    def merge_ngrams(
        self, hidden_states, attention_mask,
        batch_size, seq_length,
        aux_embeddings=0,
        original_index=None,
        ngram_index=None,
    ):
        """Sum, for every token, its gated states from all windows containing it.

        ``hidden_states`` is modified in place: rows whose centre is not a
        real token (``ngram_index`` False) are zeroed first. The merged
        vectors are scattered to the positions selected by ``original_index``
        in a zero tensor of shape ``(batch_size, seq_length, hidden)``, then
        ``aux_embeddings`` is added and the parent model's ``norm`` applied.
        """
        center = self.left

        model = self.model[0]
        norm = model.norm
        attn = model.attn

        hidden_states[(~ngram_index).unsqueeze(1).repeat(1, self.ngram)] = 0.

        if self.shallow_config.hidden_size != self.config.hidden_size:
            hidden_states = self.linear(hidden_states)

        attention_mask = attention_mask.type(hidden_states.dtype)

        # Sigmoid gate per window entry; masked entries contribute nothing.
        ngram_hidden_state = hidden_states * attn(hidden_states).sigmoid() * attention_mask.unsqueeze(-1)
        flat_hidden_state = ngram_hidden_state[:, center]  # (seq0 + seq1 + ...)

        # Row i + j holds the token of row i at column ``center - j``
        # (left context of a later window) ...
        for j in range(1, self.left + 1):
            flat_hidden_state[:-j] = flat_hidden_state[:-j] + ngram_hidden_state[j:, center - j]
        # ... and row i - j holds it at column ``center + j``
        # (right context of an earlier window).
        for j in range(1, self.right + 1):
            flat_hidden_state[j:] = flat_hidden_state[j:] + ngram_hidden_state[:-j, center + j]

        hidden_state = torch.zeros(
            [batch_size, seq_length, hidden_states.size(-1)],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        hidden_state[original_index] = flat_hidden_state[ngram_index]

        hidden_state = norm(hidden_state + aux_embeddings)

        return hidden_state

    def forward(
        self,
        input_ids,
        token_type_ids,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=True,
        output_hidden_states=True,
    ):
        """Encode ``input_ids`` with the shallow layers and return merged states.

        The result has shape ``(batch_size, seq_length, hidden)`` as built by
        ``merge_ngrams``. ``output_hidden_states`` is accepted but not used.
        """
        model = self.model[0]
        device = input_ids.device
        batch_size, seq_length = input_ids.shape

        # Position and token-type information is added after merging, using
        # the parent's secondary embedding tables.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
        aux_embeddings = model.embeddings.position_embeddings2(position_ids)  # (bs, max_seq_length, dim)
        aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)

        input_ngram_ids, token_ngram_type_ids, ngram_attention_mask, original_index = self.build_input_ngrams(
            input_ids, token_type_ids
        )

        # The -1 border markers are not valid embedding indices; map them to 0.
        window_ids = input_ngram_ids.clamp(min=0)
        extended_attention_mask = model.get_extended_attention_mask(
            ngram_attention_mask, input_ngram_ids.shape, device
        )

        # Rows whose centre column holds a real token (not a border marker).
        ngram_index = (input_ngram_ids[:, self.left] > 0)

        hidden_states = model.embeddings(input_ids=window_ids, token_type_ids=token_ngram_type_ids)

        num_shallow_layers = self.config.num_hidden_layers - self.config.num_full_hidden_layers
        for i, layer_module in enumerate(model.encoder.layer[:num_shallow_layers]):

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=extended_attention_mask,
                head_mask=layer_head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.merge_ngrams(
            hidden_states, ngram_attention_mask,
            batch_size, seq_length,
            aux_embeddings=aux_embeddings,
            original_index=original_index,
            ngram_index=ngram_index,
        )

        return hidden_states
    
    
class SkipBertModel(BertModel):
    """BertModel whose embeddings and encoder are rebuilt around a shallow config.

    The shallow config is a deep copy of ``config`` with ``hidden_size`` set
    to 768 and ``intermediate_size`` set to 3072. It is used for the
    embeddings and handed to ``SkipBertEncoder`` together with ``config``.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        shallow_config = copy.deepcopy(config)
        shallow_config.hidden_size = 768
        shallow_config.intermediate_size = 3072
        self.shallow_config = shallow_config

        # Replace the modules created by the parent constructor.
        self.embeddings = BertEmbeddings(shallow_config)
        self.encoder = SkipBertEncoder(shallow_config, config)

        if add_pooling_layer:
            self.pooler = BertPooler(config)
        else:
            self.pooler = None

        self.init_weights()
        
class SkipBertEncoder(BertEncoder):
    """Encoder stack mixing shallow-config layers with full-config layers.

    The first ``num_hidden_layers - num_full_hidden_layers`` layers are built
    from ``shallow_config``; the remaining ``num_full_hidden_layers`` layers
    are built from ``config``.
    """

    def __init__(self, shallow_config, config):
        # Deliberately skips BertEncoder.__init__ and initialises the next
        # class in the MRO, so the parent does not build its own layer stack.
        super(BertEncoder, self).__init__()
        self.config = config
        self.shallow_config = shallow_config

        full_count = config.num_full_hidden_layers
        shallow_count = config.num_hidden_layers - full_count

        layers = [BertLayer(shallow_config) for _ in range(shallow_count)]
        layers.extend(BertLayer(config) for _ in range(full_count))
        self.layer = nn.ModuleList(layers)
    
class SkipBertModel(SkipBertModel):
    """SkipBertModel extended with n-gram shallow skipping and its forward pass.

    Adds the layer norm, the gating projection and the secondary
    position/token-type embeddings that ``ShallowSkipping`` reads from this
    model, and runs the full-width encoder layers on the merged output.
    """

    def __init__(self, *nargs, **kargs):
        super().__init__(*nargs, **kargs)
        self.norm = nn.LayerNorm(self.config.hidden_size)
        self.attn = nn.Linear(self.config.hidden_size, 1)
        self.embeddings.position_embeddings2 = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)
        self.embeddings.token_type_embeddings2 = nn.Embedding(self.config.type_vocab_size, self.config.hidden_size)

        self.shallow_skipping = ShallowSkipping(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=True,
        output_hidden_states=True,
        return_dict=False,
    ):
        """Run the shallow (local) layers, then the full (global) layers.

        ``position_ids`` is accepted but not used, and ``return_dict`` is
        ignored: the result is always the tuple ``(sequence_output,
        pooled_output, all_hidden_states, all_self_attentions)``.
        ``pooled_output`` is ``None`` when the model has no pooler and
        ``all_self_attentions`` is ``None`` when ``output_attentions`` is
        falsy. ``all_hidden_states`` starts with the merged output of the
        shallow stage, followed by the output of every full layer.

        The shallow stage is always called with ``input_ids``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
            AssertionError: if ``output_hidden_states`` resolves to a falsy value.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        assert output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            # Token id 0 is treated as padding.
            attention_mask = (input_ids != 0).float()
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # Local transformer layers
        hidden_states = self.shallow_skipping(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        # Global transformer layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        past_key_value = None

        # Slice from the front: ``layer[-0:]`` would select every layer when
        # ``num_full_hidden_layers`` is 0 instead of none.
        num_shallow_layers = self.config.num_hidden_layers - self.config.num_full_hidden_layers

        for i, layer_module in enumerate(self.encoder.layer[num_shallow_layers:]):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i + num_shallow_layers] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=extended_attention_mask,
                head_mask=layer_head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        sequence_output = hidden_states
        # The pooler is None when the model was built with add_pooling_layer=False.
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return (sequence_output, pooled_output, all_hidden_states, all_self_attentions)
    
    

class SkipBertForPreTraining(BertPreTrainedModel):
    """SkipBertModel wrapper returning attentions and per-layer hidden states.

    When ``config.hidden_size`` differs from ``fit_size`` (768), each hidden
    state is passed through its own linear projection to ``fit_size``
    features; otherwise the hidden states are returned unchanged.
    """

    def __init__(self, config):
        super().__init__(config)
        self.fit_size = 768
        self.bert = SkipBertModel(config)
        self.cls = BertPreTrainingHeads(config)

        if config.hidden_size != self.fit_size:
            projections = [
                nn.Linear(config.hidden_size, self.fit_size)
                for _ in range(config.num_hidden_layers + 1)
            ]
            self.fit_denses = nn.ModuleList(projections)

    def forward(self, input_ids, token_type_ids=None,
                attention_mask=None, masked_lm_labels=None,
                next_sentence_label=None, labels=None,
                output_attentions=True, output_hidden_states=True,):
        """Return ``(attentions, hidden_states)`` from the underlying model.

        ``masked_lm_labels``, ``next_sentence_label`` and ``labels`` are
        accepted but not used here.
        """
        _, _, layer_states, att_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = layer_states
        if self.config.hidden_size != self.fit_size:
            # Project every hidden state with the projection of the same index.
            sequence_output = [
                self.fit_denses[layer_idx](states)
                for layer_idx, states in enumerate(layer_states)
            ]

        return att_output, sequence_output