import torch.nn.functional as F
from torch import nn
import torch


class MASKBertModel(nn.Module):
    """Run an encoder on pre-computed embeddings and return the [MASK] states.

    Args:
        bert: Encoder module. It is moved to ``device`` and later called as
            ``bert(inputs_embeds=...)``; its result must be indexable with
            ``'last_hidden_state'``.
        tokenizer: Object exposing ``token_to_id``; used to look up the id of
            the ``'[MASK]'`` token.
        device: Target passed to ``bert.to``.
    """

    def __init__(self, bert, tokenizer, device):
        super().__init__()
        self.bert = bert.to(device)
        self.tokenizer = tokenizer
        self.mask_id = tokenizer.token_to_id('[MASK]')

    def forward(self, inputs_embeds, input_ids):
        """Return the encoder hidden state at each row's mask position(s).

        Args:
            inputs_embeds: Embeddings fed to the encoder.
            input_ids: Per-row token ids, used only to locate ``mask_id``.
                Assumed to be aligned with ``inputs_embeds`` -- TODO confirm.

        Returns:
            The per-row selected hidden states combined with ``torch.stack``.
        """
        encoded = self.bert(inputs_embeds=inputs_embeds)
        selected = []
        for row_ids, row_states in zip(input_ids, encoded['last_hidden_state']):
            # squeeze() turns a single hit into a 0-d index, so one mask per
            # row selects a single hidden-state vector.
            positions = (row_ids == self.mask_id).nonzero().squeeze()
            selected.append(row_states[positions])
        # torch.stack needs every row to have produced the same shape.
        return torch.stack(selected)
class MASKBertTextModel(nn.Module):
    """Encode tokenized text and return the encoder's pooled output.

    Args:
        bert: Encoder module. It is moved to ``device`` and must accept
            ``input_ids``, ``token_type_ids`` and ``attention_mask`` keyword
            arguments and return an object with a ``pooler_output`` attribute.
        tokenizer: Stored on the instance; not used by ``forward``.
        device: Target passed to ``bert.to``.
    """

    def __init__(self, bert, tokenizer, device):
        super().__init__()
        self.bert = bert.to(device)
        self.tokenizer = tokenizer

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return ``pooler_output`` of the encoder for the given batch.

        Args:
            input_ids: Token ids.
            token_type_ids: Segment ids.
            attention_mask: Attention mask.

        Returns:
            The encoder's ``pooler_output``.
        """
        # Pass by keyword: Hugging Face's BertModel.forward takes
        # attention_mask *before* token_type_ids, so passing the three
        # tensors positionally in this method's order would swap them.
        encoded = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return encoded.pooler_output
class MLP(nn.Module):
    """Two-layer perceptron with a residual semantic input and a label head.

    Args:
        input_size: Feature size of ``x`` given to ``forward``.
        hidden_size: Width of the hidden layer.
        output_size: Output width of the second layer; ``semantic_x`` is added
            to this output, so it has to be broadcastable with it.
        num_label: Number of outputs of the final linear classifier.
    """

    def __init__(self, input_size, hidden_size, output_size, num_label):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.classifier = nn.Linear(output_size, num_label)

    def forward(self, semantic_x, x):
        """Project ``x``, add ``semantic_x`` as a residual, then classify.

        Args:
            semantic_x: Tensor added to the projection of ``x``.
            x: Input features for the first linear layer.

        Returns:
            Output of the final linear classifier.
        """
        hidden = self.relu(self.fc1(x))
        projected = self.fc2(hidden)
        return self.classifier(projected + semantic_x)
#
class Classifier(nn.Module):
    """Single linear layer mapping hidden features to label scores.

    Args:
        hidden_size: Size of the incoming feature vector.
        num_label: Number of output scores.
    """

    def __init__(self, hidden_size, num_label):
        super().__init__()
        self.linear = nn.Linear(hidden_size, num_label)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        return self.linear(x)
class AttentionFusionModule(nn.Module):
    """Fuse two feature vectors by summing a per-input attention transform.

    Each input is projected to query/key/value vectors with shared linear
    layers. The outer product of query and key (scaled by
    ``sqrt(input_size)``) is soft-maxed over its last dimension and applied to
    the value vector. The two transformed inputs are then added.

    Args:
        input_size: Feature size of both inputs and of the fused output.
    """

    def __init__(self, input_size):
        super().__init__()
        self.linear_q = nn.Linear(input_size, input_size)
        self.linear_k = nn.Linear(input_size, input_size)
        self.linear_v = nn.Linear(input_size, input_size)
        self.d = input_size ** 0.5

    def _attend(self, features):
        """Apply the feature-wise attention transform to one 2-D input."""
        query = self.linear_q(features)
        key = self.linear_k(features)
        value = self.linear_v(features)
        # (B, D, 1) x (B, 1, D) -> (B, D, D) score matrix per sample.
        scores = F.softmax(
            torch.bmm(query.unsqueeze(2), key.unsqueeze(1)) / self.d, dim=-1
        )
        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also remove the batch dimension when the batch holds one sample.
        return torch.bmm(scores, value.unsqueeze(2)).squeeze(-1)

    def forward(self, input1, input2):
        """Return the sum of the attention transforms of both inputs.

        Args:
            input1: First feature tensor; the ``unsqueeze``/``bmm`` calls
                require it to be 2-D ``(batch, input_size)``.
            input2: Second feature tensor, same layout as ``input1``.

        Returns:
            Tensor of shape ``(batch, input_size)``, including for batch
            size 1.
        """
        return self._attend(input1) + self._attend(input2)


class MainModel(nn.Module):
    """Combine an encoder state with structural embeddings and classify.

    Args:
        bert: Callable invoked as ``bert(inputs_embeds=..., input_ids=...)``
            whose result is used directly as a tensor.
        transE: Callable returning ``(head, relation, tail)`` embeddings and
            exposing ``InfoNCE_loss(heads, relations, labels)``.
        mlp: Callable invoked as ``mlp(hidden_state, fused_features)``.
    """

    def __init__(self, bert, transE, mlp):
        super().__init__()
        self.bert = bert
        self.transE = transE
        self.mlp = mlp

    def forward(self, input_embeddings, input_ids, heads, relations, labels):
        """Return ``(mlp_output, structure_loss)`` for one batch.

        The hidden state ``h`` and the structural vector ``s = head +
        relation`` are concatenated as ``[h, h - s, h * s, s]`` along the last
        dimension and handed to ``mlp`` together with ``h``.
        """
        semantic = self.bert(inputs_embeds=input_embeddings, input_ids=input_ids)
        head_vec, relation_vec, _ = self.transE(heads, relations, labels)
        structural = head_vec + relation_vec
        fused = torch.cat(
            (semantic, semantic - structural, semantic * structural, structural),
            dim=-1,
        )
        prediction = self.mlp(semantic, fused)
        return prediction, self.transE.InfoNCE_loss(heads, relations, labels)

class MainfusionModel(nn.Module):
    """Fuse an encoder state with structural embeddings via a fusion module.

    Args:
        bert: Callable invoked as ``bert(inputs_embeds=..., input_ids=...)``.
        transE: Callable returning ``(head, relation, tail)`` embeddings and
            exposing ``InfoNCE_loss(heads, relations, labels)``.
        classifier: Callable applied to the fused representation.
        fusion_module: Callable invoked as ``fusion_module(hidden, head +
            relation)``.
    """

    def __init__(self, bert, transE, classifier, fusion_module):
        super().__init__()
        self.bert = bert
        self.transE = transE
        self.classifier = classifier
        self.fusion_module = fusion_module

    def forward(self, input_embeddings, input_ids, heads, relations, labels):
        """Return ``(classifier_output, structure_loss)`` for one batch."""
        semantic = self.bert(inputs_embeds=input_embeddings, input_ids=input_ids)
        head_vec, relation_vec, _ = self.transE(heads, relations, labels)
        fused = self.fusion_module(semantic, head_vec + relation_vec)
        prediction = self.classifier(fused)
        return prediction, self.transE.InfoNCE_loss(heads, relations, labels)
class simpleMLP(nn.Module):
    """Single fully connected layer from ``input_size`` to ``output_size``."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        return self.fc(x)
class MASKBertFusionModel(nn.Module):
    """Inject structural embeddings into encoder inputs, then classify [MASK].

    The input embeddings at sequence positions 1 and 2 are concatenated with
    the head and relation embeddings from ``embedding_model``, projected by
    ``simpleMLP`` and written back at those positions before the encoder runs.
    The encoder hidden state at each row's ``[MASK]`` position is classified.

    Args:
        bert: Encoder module. It is moved to ``device`` and called as
            ``bert(inputs_embeds=...)``; its result must be indexable with
            ``'last_hidden_state'``.
        embedding_model: Callable returning ``(head, relation, tail)``
            embeddings and exposing ``get_loss(heads, relations, tails)``.
        simpleMLP: Callable projecting the concatenated
            ``[semantic, structural]`` vector back to the encoder input width.
        tokenizer: Object exposing ``token_to_id``; used to look up
            ``'[MASK]'``.
        classifier: Callable applied to the gathered mask hidden states.
        device: Target passed to ``bert.to``.
    """

    def __init__(self, bert, embedding_model, simpleMLP, tokenizer, classifier, device):
        super().__init__()
        self.bert = bert.to(device)
        self.tokenizer = tokenizer
        self.mask_id = tokenizer.token_to_id('[MASK]')
        self.embedding_model = embedding_model
        self.classifier = classifier
        self.simplemlp = simpleMLP
        # Not read by forward(); kept so existing attribute access keeps working.
        self.crit = nn.NLLLoss()
        self.loss_weight = 0.5

    def forward(self, inputs_embeds, input_ids, heads, relations, tails):
        """Return ``(logit, structure_loss)`` for one batch.

        Args:
            inputs_embeds: Encoder input embeddings, indexed as
                ``[:, position, :]``. Positions 1 and 2 are assumed to hold the
                head and relation tokens -- TODO confirm against the caller.
                The tensor is not modified.
            input_ids: Per-row token ids, used only to locate ``mask_id``.
            heads: Passed through to ``embedding_model``.
            relations: Passed through to ``embedding_model``.
            tails: Passed through to ``embedding_model``.

        Returns:
            Tuple of the classifier output for the mask hidden states and the
            loss returned by ``embedding_model.get_loss``.
        """
        head_embeddings, relation_embeddings, _ = self.embedding_model(heads, relations, tails)

        # Pre-fusion. Write into a copy so the caller's tensor is not mutated;
        # an in-place write would also fail on a leaf tensor requiring grad.
        fused_embeds = inputs_embeds.clone()
        fusion_head = torch.cat((inputs_embeds[:, 1, :], head_embeddings), dim=-1)
        fusion_relation = torch.cat((inputs_embeds[:, 2, :], relation_embeddings), dim=-1)
        fusion_head_output = self.simplemlp(fusion_head)
        fusion_relation_output = self.simplemlp(fusion_relation)
        fused_embeds[:, 1, :] = fusion_head_output
        fused_embeds[:, 2, :] = fusion_relation_output

        output = self.bert(inputs_embeds=fused_embeds)

        mask_hidden_state = []
        for row_ids, row_states in zip(input_ids, output['last_hidden_state']):
            # squeeze() turns a single hit into a 0-d index, so one mask per
            # row selects a single hidden-state vector.
            mask_positions = (row_ids == self.mask_id).nonzero().squeeze()
            mask_hidden_state.append(row_states[mask_positions])
        # torch.stack needs every row to have produced the same shape.
        mask_hidden_state = torch.stack(mask_hidden_state)

        logit = self.classifier(mask_hidden_state)
        structure_loss = self.embedding_model.get_loss(heads, relations, tails)
        return logit, structure_loss

class NewMASKBertFusionModel(nn.Module):
    """Inject TransE embeddings into encoder inputs, then classify [MASK].

    The input embeddings at sequence positions 1 and 2 are concatenated with
    the head and relation embeddings from ``transE``, projected by
    ``simpleMLP`` and written back at those positions before the encoder runs.
    The encoder hidden state at each row's ``[MASK]`` position is classified.

    Args:
        bert: Encoder module. It is moved to ``device`` and called as
            ``bert(inputs_embeds=...)``; its result must be indexable with
            ``'last_hidden_state'``.
        transE: Callable returning ``(head, relation, tail)`` embeddings and
            exposing ``InfoNCE_loss(heads, relations, tails)``.
        simpleMLP: Callable projecting the concatenated
            ``[semantic, structural]`` vector back to the encoder input width.
        tokenizer: Object exposing ``token_to_id``; used to look up
            ``'[MASK]'``.
        classifier: Callable applied to the gathered mask hidden states.
        device: Target passed to ``bert.to``.
    """

    def __init__(self, bert, transE, simpleMLP, tokenizer, classifier, device):
        super().__init__()
        self.bert = bert.to(device)
        self.tokenizer = tokenizer
        self.mask_id = tokenizer.token_to_id('[MASK]')
        self.transE = transE
        self.classifier = classifier
        self.simplemlp = simpleMLP

    def forward(self, inputs_embeds, input_ids, heads, relations, tails):
        """Return ``(logit, structure_loss)`` for one batch.

        Args:
            inputs_embeds: Encoder input embeddings, indexed as
                ``[:, position, :]``. Positions 1 and 2 are assumed to hold the
                head and relation tokens -- TODO confirm against the caller.
                The tensor is not modified.
            input_ids: Per-row token ids, used only to locate ``mask_id``.
            heads: Passed through to ``transE``.
            relations: Passed through to ``transE``.
            tails: Passed through to ``transE``.

        Returns:
            Tuple of the classifier output for the mask hidden states and the
            loss returned by ``transE.InfoNCE_loss``.
        """
        head_embeddings, relation_embeddings, _ = self.transE(heads, relations, tails)

        # Pre-fusion. Write into a copy so the caller's tensor is not mutated;
        # an in-place write would also fail on a leaf tensor requiring grad.
        fused_embeds = inputs_embeds.clone()
        fusion_head = torch.cat((inputs_embeds[:, 1, :], head_embeddings), dim=-1)
        fusion_relation = torch.cat((inputs_embeds[:, 2, :], relation_embeddings), dim=-1)
        fusion_head_output = self.simplemlp(fusion_head)
        fusion_relation_output = self.simplemlp(fusion_relation)
        fused_embeds[:, 1, :] = fusion_head_output
        fused_embeds[:, 2, :] = fusion_relation_output

        output = self.bert(inputs_embeds=fused_embeds)

        mask_hidden_state = []
        for row_ids, row_states in zip(input_ids, output['last_hidden_state']):
            # squeeze() turns a single hit into a 0-d index, so one mask per
            # row selects a single hidden-state vector.
            mask_positions = (row_ids == self.mask_id).nonzero().squeeze()
            mask_hidden_state.append(row_states[mask_positions])
        # torch.stack needs every row to have produced the same shape.
        mask_hidden_state = torch.stack(mask_hidden_state)

        structure_loss = self.transE.InfoNCE_loss(heads, relations, tails)
        logit = self.classifier(mask_hidden_state)
        return logit, structure_loss