from transformers import BertPreTrainedModel, BertModel, BertConfig
from transformers.file_utils import add_start_docstrings_to_callable
from transformers.modeling_bert import BERT_INPUTS_DOCSTRING, BertOnlyMLMHead
import torch
from torch import nn
import logging
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
# Seed constant (not referenced by the code in this module's visible classes).
SEED_NUM: int = 800


class RobustBert(BertPreTrainedModel):
    """BERT masked-LM whose input embeddings are enriched with shape and pronunciation ids.

    Compared with ``BertAddExternelLoss`` this model additionally takes glyph-shape and
    pronunciation id inputs. The fusion mode is selected by ``config.concat_input``:

    * concat mode (truthy): the word embedding is concatenated with every enabled
      auxiliary embedding, projected back to ``hidden_size`` by ``map_inputs_layer``,
      then position/token-type embeddings, LayerNorm and dropout from ``self.bert``
      are applied (see ``get_input_concat_embedding``).
    * plus mode (falsy): the output of ``self.bert.embeddings`` and every enabled
      auxiliary embedding are summed and divided by the number of parts, so the
      auxiliary embedding dims must be addable to ``hidden_size``
      (see ``get_input_plus_embedding``).

    Config attributes read besides the standard BERT ones: ``enable_cls``,
    ``enable_pronunciation``, ``enable_shape``, ``concat_input``,
    ``pronunciation_vocab_size``, ``pronunciation_embed_dim``, ``shape_vocab_size``,
    ``shape_embed_dim`` and the optional ``contrastive_scale`` (default 100).
    ``enable_cls`` is stored but not used by this class.
    """

    def __init__(self,
                 config: BertConfig,
                 shape_embed=None,
                 pronunciation_embed=None):
        """Build the encoder, MLM head, optional projection and auxiliary embeddings.

        Args:
            config: BERT config extended with the settings listed in the class docstring.
            shape_embed: optional weight tensor for the shape embedding table. When
                given it replaces the random initialisation and is frozen.
            pronunciation_embed: same as ``shape_embed``, for the pronunciation table.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.enable_cls = config.enable_cls
        self.enable_pronunciation = config.enable_pronunciation
        self.enable_shape = config.enable_shape
        # Logit scale applied to the cosine similarities in the contrastive pooler loss.
        # The default keeps the previously hard-coded value.
        self.contrastive_scale = getattr(config, "contrastive_scale", 100)

        if config.concat_input:
            # The word embedding (hidden_size) plus each enabled auxiliary embedding are
            # concatenated, then projected back to hidden_size.
            input_dim = config.hidden_size
            if self.enable_pronunciation:
                input_dim += config.pronunciation_embed_dim
            if self.enable_shape:
                input_dim += config.shape_embed_dim
            self.map_inputs_layer = torch.nn.Linear(input_dim, config.hidden_size)
            logger.info(f"init map_inputs_layer, input_dim:{input_dim}")
        else:
            self.map_inputs_layer = None

        # In this version of the experiment the auxiliary embeddings are randomly
        # initialised first and only overwritten when weights are passed in.
        if self.enable_pronunciation:
            self.pronunciation_embed = self._build_aux_embedding(
                config.pronunciation_vocab_size, config.pronunciation_embed_dim,
                pronunciation_embed, "pronunciation")
        else:
            self.pronunciation_embed = None

        if self.enable_shape:
            self.shape_embed = self._build_aux_embedding(
                config.shape_vocab_size, config.shape_embed_dim, shape_embed, "shape")
        else:
            self.shape_embed = None

        logger.info(f"cur_model_setting, enable_cls={self.enable_cls}, enable_shape={self.enable_shape}, self.enable_pronunciation={self.enable_pronunciation}")

    @staticmethod
    def _build_aux_embedding(vocab_size, embed_dim, pretrained_weight=None, name=""):
        """Create an auxiliary embedding table, optionally loading frozen weights.

        The shape of ``pretrained_weight`` is not checked against
        ``(vocab_size, embed_dim)``; when given, it simply replaces the table.
        """
        embedding = nn.Embedding(vocab_size, embed_dim)
        if pretrained_weight is not None:
            embedding.weight = nn.Parameter(pretrained_weight)
            embedding.weight.requires_grad = False
            logger.info(f"init {name} embed done")
        return embedding

    def get_input_plus_embedding(self, input_ids, shape_ids, pronunciation_ids, position_ids=None, token_type_ids=None):
        """Fuse embeddings by averaging. This is used when ``map_inputs_layer`` is None.

        Returns the element-wise mean of ``self.bert.embeddings(...)`` and every enabled
        auxiliary embedding. Token type ids default to all zeros.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_ids.size(), dtype=torch.long, device=input_ids.device)

        embedding_in = self.bert.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        num_parts = 1

        # Out-of-place adds: same values as accumulating into a clone, without the copy.
        if self.enable_shape:
            embedding_in = embedding_in + self.shape_embed(shape_ids)
            num_parts += 1
        if self.enable_pronunciation:
            embedding_in = embedding_in + self.pronunciation_embed(pronunciation_ids)
            num_parts += 1

        return embedding_in / num_parts

    def get_input_concat_embedding(self, input_ids, shape_ids, pronunciation_ids, position_ids=None, token_type_ids=None):
        """Fuse embeddings by concatenation and a linear projection (concat mode).

        The concatenation order is [word, shape, pronunciation]. ``map_inputs_layer``
        weights depend on this order, so do not reorder it. After the projection,
        BERT's position and token-type embeddings, LayerNorm and dropout are applied.
        Position ids default to ``0..seq_len-1``; token type ids default to zeros.
        """
        device = input_ids.device
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embeddings = self.bert.embeddings
        embedding_in = embeddings.word_embeddings(input_ids)
        position_embeddings = embeddings.position_embeddings(position_ids)
        token_type_embeddings = embeddings.token_type_embeddings(token_type_ids)

        if self.enable_shape:
            embedding_in = torch.cat((embedding_in, self.shape_embed(shape_ids)), -1)
        if self.enable_pronunciation:
            embedding_in = torch.cat((embedding_in, self.pronunciation_embed(pronunciation_ids)), -1)

        embedding_in = self.map_inputs_layer(embedding_in)  # batch_size * seq_len * hidden_dim
        embedding_in = embedding_in + position_embeddings + token_type_embeddings
        embedding_in = embeddings.LayerNorm(embedding_in)
        embedding_in = embeddings.dropout(embedding_in)

        return embedding_in

    def get_bert_output(self, input_ids, shape_ids, pronunciation_ids, attention_mask=None, position_ids=None, head_mask=None, token_type_ids=None):
        """Run the fused embeddings through BERT's encoder and pooler.

        Args:
            attention_mask: optional mask over ``input_ids``. When None, every position
                is attended to, mirroring ``BertModel.forward``.

        Returns:
            ``(sequence_output, pooled_output)``: the encoder's last hidden states and
            the pooler output.
        """
        device = input_ids.device
        input_shape = input_ids.size()
        if attention_mask is None:
            # get_extended_attention_mask cannot take None. Default to "attend everywhere".
            attention_mask = torch.ones(input_shape, device=device)

        if self.map_inputs_layer is None:
            embedding_in = self.get_input_plus_embedding(input_ids, shape_ids, pronunciation_ids, position_ids, token_type_ids)
        else:
            embedding_in = self.get_input_concat_embedding(input_ids, shape_ids, pronunciation_ids, position_ids, token_type_ids)

        encoder_outputs = self.bert.encoder(
            embedding_in,
            attention_mask=self.get_extended_attention_mask(attention_mask, input_shape, device),
            head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers),
            output_attentions=self.config.output_attentions,
            output_hidden_states=self.config.output_hidden_states,
        )
        sequence_output = encoder_outputs[0]  # batch_size*seq_len*hidden_size
        pooled_output = self.bert.pooler(sequence_output)  # batch_size*hidden_dim

        return sequence_output, pooled_output

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(self,
                input_ids=None,
                attention_mask=None,
                input_shape_ids=None,
                input_pronunciation_ids=None,
                attack_sample_ids=None,  # currently a single attack sample per input, batch_size * seq_len
                attack_sample_mask=None,  # currently a single attack sample per input, batch_size * seq_len
                attack_sample_shape_ids=None,
                attack_sample_pronunciation_ids=None,
                labels=None,
                label_shape_ids=None,
                label_pronunciation_ids=None,
                **kwargs,
                ):
        """
        Returns:
            * ``labels`` is None: ``prediction_scores[0]``, the MLM scores of the first
              batch element only.
            * ``labels`` given and no attack samples:
              ``(mlm_loss, prediction_scores[:, 1:, :])``, with the first position dropped.
            * ``labels`` and ``attack_sample_ids`` given:
              ``(mlm_loss + pooler_loss, prediction_scores[:, 1:, :])``. Here
              ``pooler_loss`` is an in-batch contrastive loss. It pulls the pooled
              output of each input, and of its target sentence, towards the pooled
              output of its own attack sample.
        """
        # Step 1: MLM predictions for the (possibly noisy) input.
        sequence_output, pooled_output = self.get_bert_output(input_ids, input_shape_ids, input_pronunciation_ids, attention_mask)
        prediction_scores = self.cls(sequence_output)  # batch_size*seq_len*vocab_size

        if labels is None:
            # NOTE(review): this indexes the *batch* dimension, i.e. it returns only the
            # first example's scores. It is kept for compatibility with existing callers;
            # confirm this is intended for batched inference.
            return prediction_scores[0]

        loss_fct = torch.nn.CrossEntropyLoss()
        mlm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if attack_sample_ids is None:
            return mlm_loss, prediction_scores[:, 1:, :]

        batch_size, _ = labels.shape

        # Encode the target sentence. Ignored label positions (-100) are replaced by
        # token id 0 so they are valid embedding indices.
        target_inputs = torch.clone(labels)
        target_inputs[target_inputs == -100] = 0
        _, label_pooled_output = self.get_bert_output(target_inputs, label_shape_ids, label_pronunciation_ids, attention_mask)
        _, attack_pooled_output = self.get_bert_output(attack_sample_ids, attack_sample_shape_ids, attack_sample_pronunciation_ids, attack_sample_mask)

        normalize = torch.nn.functional.normalize
        pooled_output_norm = normalize(pooled_output, dim=-1)  # batch_size * hidden_dim
        attack_pooled_output_norm = normalize(attack_pooled_output, dim=-1)  # batch_size * hidden_dim
        label_pooled_output_norm = normalize(label_pooled_output, dim=-1)  # batch_size * hidden_dim
        # Cosine-similarity matrices, batch_size * batch_size. Row i's positive is
        # attack sample i; the other rows' attack samples act as in-batch negatives.
        sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T)
        sim_matrix_target = torch.matmul(label_pooled_output_norm, attack_pooled_output_norm.T)
        batch_labels = torch.arange(batch_size, device=input_ids.device)

        # Average the contrastive losses of (input vs attack) and (target vs attack).
        scale = self.contrastive_scale
        pooler_loss = (loss_fct(scale * sim_matrix.view(batch_size, -1), batch_labels)
                       + loss_fct(scale * sim_matrix_target.view(batch_size, -1), batch_labels)) / 2

        return pooler_loss + mlm_loss, prediction_scores[:, 1:, :]


class RobustBertAdvanceClassification(RobustBert):
    """Sequence classifier on top of ``RobustBert``'s pooled output (adds a CLS head).

    The masked-LM head ``self.cls`` inherited from ``RobustBert`` is still constructed,
    but this class's ``forward`` does not use it.
    """

    def __init__(self, config: BertConfig, shape_embed=None, pronunciation_embed=None):
        """Build the shared RobustBert body plus a dropout + linear classification head.

        Args:
            config: the same config ``RobustBert`` expects, plus ``num_labels`` and
                ``hidden_dropout_prob``.
            shape_embed: optional frozen shape embedding weights. These are forwarded
                to ``RobustBert``; previously they could not be supplied here.
            pronunciation_embed: optional frozen pronunciation embedding weights,
                forwarded to ``RobustBert``.
        """
        super().__init__(config, shape_embed=shape_embed, pronunciation_embed=pronunciation_embed)
        self.num_labels = config.num_labels
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(self,
                input_ids=None,
                input_shape_ids=None,
                input_pronunciation_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None,
                position_ids=None,
                ):
        """
        Returns:
            ``(loss, logits)`` when ``labels`` is given. The loss is the cross-entropy
            over ``num_labels`` classes. Otherwise only ``logits``, shaped
            batch_size * num_labels.
        """
        _, pooled_output = self.get_bert_output(input_ids=input_ids, shape_ids=input_shape_ids,
                                                pronunciation_ids=input_pronunciation_ids,
                                                attention_mask=attention_mask,
                                                position_ids=position_ids,
                                                token_type_ids=token_type_ids)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)  # batch_size * num_labels

        if labels is None:
            return logits

        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits

