import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from onmt.BertModules import *
import pdb, numpy
from onmt.GraphBert import GATResMergerLayer
from torch.utils.checkpoint import checkpoint


class knowledge_integration(nn.Module):
    """Attend from each BERT token over its candidate KB embeddings plus a sentinel.

    For every token, the candidate content embeddings (looked up in a frozen,
    pre-trained table) and one extra "sentinel" vector are projected to the
    BERT hidden size. They are then scored against the token's hidden state
    by dot product. The masked softmax weights mix the *unprojected*
    candidates, so the output width is ``config.content_dim``.

    Args:
        config: model config; ``content_dim`` is read.
        bert_config: BERT config; ``hidden_size`` is read.
        content_embedding: float tensor passed to ``nn.Embedding.from_pretrained``
            (frozen), i.e. ``(num_entries, content_dim)``.
    """

    def __init__(self, config, bert_config, content_embedding):
        super(knowledge_integration, self).__init__()
        self.config = config
        self.bert_config = bert_config
        # Projects candidates (content_dim) into BERT space for scoring.
        self.linear = nn.Linear(config.content_dim, bert_config.hidden_size)
        self.content_embedding = nn.Embedding.from_pretrained(content_embedding, freeze=True)
        self.softmax = nn.Softmax(2)
        # Projects the first token's hidden state down to content_dim to form the sentinel.
        self.linear2 = nn.Linear(bert_config.hidden_size, config.content_dim)

    def forward(self, bert_output, content_ids, content_mask):
        """Return one knowledge vector per token.

        Args:
            bert_output: ``(batch, seq_len, hidden_size)`` token states.
            content_ids: ``(batch, seq_len, num_content)`` indices into the
                content embedding table.
            content_mask: 1 for usable slots, 0 for padding. It must broadcast
                against ``(batch, seq_len, num_content + 1)``, where the last
                slot is the sentinel. Float, integer and bool masks are accepted.

        Returns:
            ``(batch, seq_len, content_dim)`` attention-weighted mix of the
            candidate embeddings and the sentinel.
        """
        batch_size, seq_len, _ = bert_output.size()

        # The sentinel is built from the first token's state. Slicing keeps it on
        # bert_output's device and is equivalent to gathering index 0 along the
        # sequence axis.
        sentinel = self.linear2(bert_output[:, :1, :])  # (batch, 1, content_dim)
        sentinel = sentinel.expand(batch_size, seq_len, -1).unsqueeze(2)  # (batch, seq, 1, content_dim)

        contents = self.content_embedding(content_ids)  # (batch, seq, num_content, content_dim)
        content_all = torch.cat((contents, sentinel), 2)  # (batch, seq, num_content + 1, content_dim)

        # Dot-product score of every projected candidate against its token's state.
        content_trans = self.linear(content_all)  # (batch, seq, num_content + 1, hidden)
        atten_score = torch.matmul(content_trans, bert_output.unsqueeze(3)).squeeze(3)

        # Cast first so that bool masks also support the arithmetic below.
        content_mask = content_mask.to(atten_score.dtype)
        atten_score = atten_score + (1.0 - content_mask) * -10000.0

        atten_weight = self.softmax(atten_score)  # (batch, seq, num_content + 1)
        # Zero masked slots exactly; a row whose every slot is masked yields a zero vector.
        atten_weight = atten_weight * content_mask

        output = torch.matmul(atten_weight.unsqueeze(2), content_all).squeeze(2)
        return output


class self_matching(nn.Module):
    """Trilinear self-attention over the knowledge-enriched token representations.

    Each pair of tokens ``(i, j)`` is scored with ``linear([x_i; x_j; x_i * x_j])``.
    Key positions ``j`` excluded by ``attention_mask`` are suppressed, and the
    softmax over ``j`` gives ``A``. The output concatenates
    ``[x, A x, x - A x, x * A x, A A x, x - A A x]`` along the feature axis.

    Note: the pairwise feature tensor has shape
    ``(batch, seq_len, seq_len, 3 * embed_dim)``.
    """

    def __init__(self, config, bert_config):
        super(self_matching, self).__init__()
        self.config = config
        self.bert_config = bert_config
        # Token width: the BERT hidden state plus one content_dim slot per enabled KB.
        num_kbs = 2 if (config.wordnet and config.nell) else 1
        embed_dim = num_kbs * config.content_dim + bert_config.hidden_size
        self.linear = nn.Linear(3 * embed_dim, 1)
        self.softmax = nn.Softmax(2)

    def forward(self, integration_output, attention_mask):
        """Args:
            integration_output: ``(batch, seq_len, embed_dim)`` token vectors.
            attention_mask: ``(batch, seq_len)`` with 1 for real tokens and 0 for padding.

        Returns:
            ``(batch, seq_len, 6 * embed_dim)`` matching features.
        """
        # Shapes are taken from the input, so any sequence length is accepted.
        batch_size, seq_len, embed_dim = integration_output.size()

        # rows[b, i, j] = x[b, i]; cols[b, i, j] = x[b, j]
        rows = integration_output.unsqueeze(2).expand(batch_size, seq_len, seq_len, embed_dim)
        cols = integration_output.unsqueeze(1).expand(batch_size, seq_len, seq_len, embed_dim)
        features = torch.cat((rows, cols, rows * cols), 3)
        atten_score = self.linear(features).squeeze(3)  # (batch, seq, seq)

        # Additive mask over key positions j: 0 keeps a key, -10000 suppresses it.
        key_mask = attention_mask.unsqueeze(1).to(integration_output.dtype)
        atten_score = atten_score + (1.0 - key_mask) * -10000.0
        atten_weight = self.softmax(atten_score)

        V1 = torch.matmul(atten_weight, integration_output)
        # Second-order attention: attend with A @ A.
        V2 = torch.matmul(torch.matmul(atten_weight, atten_weight), integration_output)
        UV = integration_output * V1

        return torch.cat(
            (integration_output, V1, integration_output - V1, UV, V2, integration_output - V2), 2)


class merge(nn.Module):
    """Fuse self-matching features back into a BERT hidden state.

    The ``6 * embed_dim`` matching features are first projected to
    ``bert_config.hidden_size``. They are then combined with ``bert_output``
    by ``GATResMergerLayer``.
    """

    def __init__(self, config, bert_config):
        super(merge, self).__init__()
        self.config = config
        # Knowledge width: one content_dim slot per enabled KB.
        kb_dims = config.content_dim * 2 if (config.wordnet and config.nell) else config.content_dim
        self.linear = nn.Linear(6 * (kb_dims + bert_config.hidden_size), bert_config.hidden_size)
        self.multihead_atten = GATResMergerLayer(bert_config)

    def forward(self, match_output, bert_output, attention_mask=None):
        projected = self.linear(match_output)
        return self.multihead_atten(bert_output, projected, sent_ind=None, attention_mask=attention_mask)


class ktnet_encoder(nn.Module):
    """BERT encoder stack with a knowledge-integration / self-matching detour.

    The forward pass works as follows:

    - Run the first ``config.merge_layer - 1`` BERT layers.
    - Read knowledge from the output of (1-based) layer ``config.start_layer``.
    - Concatenate that output with the attended KB vectors and run self-matching.
    - Merge the result into the latest hidden state.
    - Feed the merged state through the remaining layers, starting at
      (1-based) layer ``config.merge_layer``.
    """

    def __init__(self, config, bert_config, wordnet_embed=None, nell_embed=None):
        super(ktnet_encoder, self).__init__()
        bert_layer = BertLayer(config, False, False)
        self.config = config
        self.bert_config = bert_config
        if config.wordnet and config.nell:
            self.integrate1 = knowledge_integration(config, bert_config, wordnet_embed)
            self.integrate2 = knowledge_integration(config, bert_config, nell_embed)
        elif config.wordnet:
            self.integrate = knowledge_integration(config, bert_config, wordnet_embed)
        elif config.nell:
            self.integrate = knowledge_integration(config, bert_config, nell_embed)
        else:
            raise ValueError('at least one of Wordnet and Nell must be True')

        self.match = self_matching(config, bert_config)
        self.bert_layers = nn.ModuleList([copy.deepcopy(bert_layer) for _ in range(config.num_hidden_layers)])
        self.merge = merge(config, bert_config)

    def _integrate_knowledge(self, layer_output, wordnet_content_id, nell_content_id, wn_mask, ne_mask):
        """Concatenate ``layer_output`` with the knowledge vectors of every enabled KB (WordNet first)."""
        if self.config.wordnet and self.config.nell:
            knowledge = [self.integrate1(layer_output, wordnet_content_id, wn_mask),
                         self.integrate2(layer_output, nell_content_id, ne_mask)]
        elif self.config.wordnet:
            knowledge = [self.integrate(layer_output, wordnet_content_id, wn_mask)]
        else:
            knowledge = [self.integrate(layer_output, nell_content_id, ne_mask)]
        return torch.cat([layer_output] + knowledge, 2)

    def forward(self, input_embedding, attention_mask, wordnet_content_id=None, nell_content_id=None, wn_mask=None,
                ne_mask=None):
        """Return the final hidden state after all ``num_hidden_layers`` layers and the merge step.

        Raises:
            ValueError: if the layer settings violate
                ``1 <= start_layer < merge_layer <= num_hidden_layers + 1``.
                In particular, ``start_layer <= 0`` would otherwise wrap to a
                negative index and silently read the wrong layer.
        """
        num_layers = len(self.bert_layers)
        # config layer numbers are 1-based
        start_layer = self.config.start_layer - 1
        merge_layer = self.config.merge_layer - 1
        if not 0 <= start_layer < merge_layer <= num_layers:
            raise ValueError(
                'expected 1 <= start_layer < merge_layer <= num_hidden_layers + 1, got start_layer={}, '
                'merge_layer={}, num_hidden_layers={}'.format(self.config.start_layer, self.config.merge_layer,
                                                              num_layers))

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(input_embedding.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        all_hidden_layers = []
        hidden_state = input_embedding
        for layer in self.bert_layers[:merge_layer]:
            hidden_state = layer(hidden_state, extended_attention_mask)
            all_hidden_layers.append(hidden_state)

        integration = self._integrate_knowledge(all_hidden_layers[start_layer], wordnet_content_id,
                                                nell_content_id, wn_mask, ne_mask)
        matching = self.match(integration, attention_mask)

        hidden_state = self.merge(matching, all_hidden_layers[-1], attention_mask=extended_attention_mask)
        for layer in self.bert_layers[merge_layer:]:
            hidden_state = layer(hidden_state, extended_attention_mask)
        return hidden_state


class output_layer(nn.Module):
    """Two linear heads that score every token for the span start and the span end."""

    def __init__(self, config, bert_config):
        super(output_layer, self).__init__()
        self.config = config
        hidden = bert_config.hidden_size
        self.linear1 = nn.Linear(hidden, 1)  # start-position logits
        self.linear2 = nn.Linear(hidden, 1)  # end-position logits

    def forward(self, bert_output):
        """Map ``(batch, seq_len, hidden)`` states to a ``(start, end)`` pair of ``(batch, seq_len)`` logits."""
        return tuple(head(bert_output).squeeze(2) for head in (self.linear1, self.linear2))


class ktnet(BertPreTrainedModel):
    """KTNet span extractor: BERT embeddings, then the knowledge-enriched encoder, then start/end logits."""

    def __init__(self, config, bert_config, wordnet_embed=None, nell_embed=None):
        super(ktnet, self).__init__(config)
        self.config = config
        self.bert_config = bert_config
        self.embeddings = BertEmbeddings(bert_config)
        self.encoder = ktnet_encoder(config, bert_config, wordnet_embed=wordnet_embed, nell_embed=nell_embed)
        self.output = output_layer(config, bert_config)

    def forward(self, input_ids, attention_mask, seg_ids, wordnet_content_id=None, nell_content_id=None, wn_mask=None,
                ne_mask=None):
        """Return ``(start_logits, end_logits)`` for every token."""
        knowledge_inputs = dict(wordnet_content_id=wordnet_content_id, nell_content_id=nell_content_id,
                                wn_mask=wn_mask, ne_mask=ne_mask)
        embedded = self.embeddings(input_ids, token_type_ids=seg_ids)
        encoded = self.encoder(embedded, attention_mask, **knowledge_inputs)
        start_logits, end_logits = self.output(encoded)
        return start_logits, end_logits


class ktnet_encoder_baseline(BertPreTrainedModel):
    """Baseline encoder: run every BERT layer, then integrate knowledge and self-match once.

    Unlike ``ktnet_encoder``, nothing is merged back into the BERT stack. The
    self-matching features computed on the final layer output are returned
    directly.
    """

    def __init__(self, config, bert_config, wordnet_embed=None, nell_embed=None):
        super(ktnet_encoder_baseline, self).__init__(config)
        bert_layer = BertLayer(config, False, False)
        self.config = config
        self.bert_config = bert_config
        if config.wordnet and config.nell:
            self.integrate1 = knowledge_integration(config, bert_config, wordnet_embed)
            self.integrate2 = knowledge_integration(config, bert_config, nell_embed)
        elif config.wordnet:
            self.integrate = knowledge_integration(config, bert_config, wordnet_embed)
        elif config.nell:
            self.integrate = knowledge_integration(config, bert_config, nell_embed)
        else:
            raise ValueError('at least one of Wordnet and Nell must be True')

        self.match = self_matching(config, bert_config)
        self.bert_layers = nn.ModuleList([copy.deepcopy(bert_layer) for _ in range(config.num_hidden_layers)])

    def _integrate_knowledge(self, layer_output, wordnet_content_id, nell_content_id, wn_mask, ne_mask):
        """Concatenate ``layer_output`` with the knowledge vectors of every enabled KB (WordNet first)."""
        if self.config.wordnet and self.config.nell:
            knowledge = [self.integrate1(layer_output, wordnet_content_id, wn_mask),
                         self.integrate2(layer_output, nell_content_id, ne_mask)]
        elif self.config.wordnet:
            knowledge = [self.integrate(layer_output, wordnet_content_id, wn_mask)]
        else:
            knowledge = [self.integrate(layer_output, nell_content_id, ne_mask)]
        return torch.cat([layer_output] + knowledge, 2)

    def forward(self, input_embedding, attention_mask, wordnet_content_id=None, nell_content_id=None, wn_mask=None,
                ne_mask=None):
        """Return the self-matching features of the final BERT layer output."""
        # Cast before the arithmetic (as ktnet_encoder does) so bool masks also work.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(input_embedding.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Only the last layer's output is needed, so intermediate states are not kept.
        hidden_state = input_embedding
        for layer in self.bert_layers:
            hidden_state = layer(hidden_state, extended_attention_mask)

        integration = self._integrate_knowledge(hidden_state, wordnet_content_id, nell_content_id, wn_mask, ne_mask)
        return self.match(integration, attention_mask)


class output_baseline(nn.Module):
    """Start/end span heads applied directly to the ``6 * embed_dim`` self-matching features."""

    def __init__(self, config, bert_config):
        super(output_baseline, self).__init__()
        self.config = config
        # embed_dim = hidden_size plus one content_dim slot per enabled KB.
        kb_dims = 2 * config.content_dim if (config.wordnet and config.nell) else config.content_dim
        in_features = 6 * (bert_config.hidden_size + kb_dims)
        self.linear1 = nn.Linear(in_features, 1)
        self.linear2 = nn.Linear(in_features, 1)

    def forward(self, bert_output):
        """Return a ``(start, end)`` pair of ``(batch, seq_len)`` logit tensors."""
        return tuple(head(bert_output).squeeze(2) for head in (self.linear1, self.linear2))


class ktnet_baseline(BertPreTrainedModel):
    """Baseline span extractor: BERT embeddings, then ``ktnet_encoder_baseline``, then ``output_baseline``."""

    def __init__(self, config, bert_config, wordnet_embed=None, nell_embed=None):
        super(ktnet_baseline, self).__init__(config)
        self.config = config
        self.bert_config = bert_config
        self.embeddings = BertEmbeddings(bert_config)
        self.encoder = ktnet_encoder_baseline(config, bert_config, wordnet_embed=wordnet_embed, nell_embed=nell_embed)
        self.output = output_baseline(config, bert_config)

    def forward(self, input_ids, attention_mask, seg_ids, wordnet_content_id=None, nell_content_id=None, wn_mask=None,
                ne_mask=None):
        """Return ``(start_logits, end_logits)`` for every token."""
        knowledge_inputs = dict(wordnet_content_id=wordnet_content_id, nell_content_id=nell_content_id,
                                wn_mask=wn_mask, ne_mask=ne_mask)
        embedded = self.embeddings(input_ids, token_type_ids=seg_ids)
        matched = self.encoder(embedded, attention_mask, **knowledge_inputs)
        start_logits, end_logits = self.output(matched)
        return start_logits, end_logits


class bert_base(BertPreTrainedModel):
    """Plain BERT span-extraction baseline: embeddings, then the BERT layers, then start/end logits."""

    def __init__(self, config, bert_config):
        super(bert_base, self).__init__(config)
        # The former `self.config = bert_config` here was dead: it was overwritten a few lines later.
        self.embeddings = BertEmbeddings(bert_config)

        bert_layer = BertLayer(config, False, False)
        self.config = config
        self.bert_config = bert_config

        self.bert_layers = nn.ModuleList([copy.deepcopy(bert_layer) for _ in range(config.num_hidden_layers)])

        self.output1 = nn.Linear(bert_config.hidden_size, 1)  # start-position logits
        self.output2 = nn.Linear(bert_config.hidden_size, 1)  # end-position logits

    def forward(self, input_ids, attention_mask, seg_ids):
        """Return ``(start_logits, end_logits)``, each ``(batch, seq_len)``."""
        hidden_state = self.embeddings(input_ids, token_type_ids=seg_ids)

        # The additive mask must share the activations' float dtype. The old code
        # cast it to input_ids.dtype, an integer type, which was inconsistent
        # with the other encoders in this file.
        extend_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(hidden_state.dtype)
        extend_mask = (1.0 - extend_mask) * -10000.0

        for layer in self.bert_layers:
            hidden_state = layer(hidden_state, extend_mask)

        start_pos = self.output1(hidden_state).squeeze(2)
        end_pos = self.output2(hidden_state).squeeze(2)
        return start_pos, end_pos





