import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from transformers import BertTokenizer, BertModel, AutoModel
from model.transformer import TransformerModel
from model.transformer_new import Transformer
from model.CHAN import ContextAttention
from model.torchcrf import CRF
from model.mia import MutualIterativeAttention

class BertContextNLU(nn.Module):
    
    def __init__(self, config, opt, num_labels=2, num_slot_labels=144):
        """Create the BERT encoder, the turn-level encoders and the prediction heads.

        Args:
            config: Object providing ``hidden_size``; used as the feature width
                of every layer that consumes BERT outputs.
            opt: Run options. ``opt.model`` selects which slot head is built and
                ``opt.rnn_hidden`` sets the LSTM width; the object itself is
                kept on ``self.opt``.
            num_labels: Output width of the turn-level classifiers.
            num_slot_labels: Output width of the token-level slot classifiers.
        """
        super(BertContextNLU, self).__init__()
        self.opt = opt
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_mode = opt.model

        self.num_labels = num_labels
        self.num_slot_labels = num_slot_labels

        self.bert = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
        # self.tod_bert = AutoModel.from_pretrained("TODBERT/TOD-BERT-JNT-V1", output_hidden_states=True, output_attentions=True)

        self.dropout = nn.Dropout(0.1)
        self.hidden_size = config.hidden_size
        self.rnn_hidden = opt.rnn_hidden

        #########################################

        # naive bert
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        nn.init.xavier_normal_(self.classifier.weight)
        self.slot_classifier_naive = nn.Linear(config.hidden_size, num_slot_labels)
        # Initialise the slot head itself; the original re-initialised
        # self.classifier here and left this layer at its default init.
        nn.init.xavier_normal_(self.slot_classifier_naive.weight)

        # self attentive
        self.linear1 = nn.Linear(config.hidden_size, 256)
        self.linear2 = nn.Linear(4*256, config.hidden_size)
        self.tanh = nn.Tanh()
        self.context_vector = nn.Parameter(torch.randn(256, 4), requires_grad=True)

        # transformer
        self.transformer_model = TransformerModel(ninp=self.hidden_size, nhead=4, nhid=64, nlayers=2, dropout=0.1)
        self.transformer_encoder = Transformer(hidden_dim=self.hidden_size,
                                               model_dim=256,
                                               num_heads=2,
                                               dropout=0.1)

        # DiSAN
        self.conv1 = nn.Conv1d(self.hidden_size, self.hidden_size, 3, padding=1)
        self.conv2 = nn.Conv1d(self.hidden_size, self.hidden_size, 3, padding=1)
        self.fc1 = nn.Linear(2*self.hidden_size, self.rnn_hidden)

        # CHAN
        self.context_encoder = ContextAttention(self.device)

        # rnn
        self.rnn = nn.LSTM(input_size=self.hidden_size,
                           hidden_size=self.rnn_hidden,
                           batch_first=True,
                           num_layers=1)

        # classifier
        self.classifier_rnn = nn.Linear(self.rnn_hidden, num_labels)
        nn.init.xavier_normal_(self.classifier_rnn.weight)
        self.classifier_bert = nn.Linear(self.hidden_size, num_labels)
        nn.init.xavier_normal_(self.classifier_bert.weight)
        self.classifier_transformer = nn.Linear(self.rnn_hidden*4, num_labels)
        nn.init.xavier_normal_(self.classifier_transformer.weight)

        # self.crf = CRF(self.num_slot_labels)

        ##############################################################################
        # knowledge
        print('Model: ', self.model_mode)

        # Every supported mode has a matching ``define_<mode>`` method that
        # creates the mode-specific slot (and, for some modes, intent) layers.
        # Any other mode value builds no extra layers, as before.
        supported_modes = (
            'baseline',
            'baseline_attention',
            'slot_embedding',
            'kg_att_context_as_hidden',
            'kg_att_context_as_input',
            'kg_att2_context_as_hidden',
            'kg_att2_context_as_hidden_slot',
            'kg_att2_context_as_hidden_gating',
            'kg_att2_context_as_hidden_lka',
            'kg_att2_context_global',
            'kg_att2_context_local',
            'glka_trans',
        )
        if self.model_mode in supported_modes:
            getattr(self, 'define_' + self.model_mode)(num_slot_labels)
        
    
    def self_attentive(self, last_hidden_states, d, b):
        # input should be (b,d,h)
        vectors = self.context_vector.unsqueeze(0).repeat(b*d, 1, 1)

        h = self.linear1(last_hidden_states) # (b*d, t, h)
        scores = torch.bmm(h, vectors) # (b*d, t, 4)
        scores = nn.Softmax(dim=1)(scores) # (b*d, t, 4)
        outputs = torch.bmm(scores.permute(0, 2, 1), h).view(b*d, -1) # (b*d, 4h)
        pooled_output = self.linear2(outputs) # (b*d, h)

        pooled_output = pooled_output.view(b,d,self.hidden_size) # (b,d,h)
        return pooled_output
    
    def mha(self, pooled_output, d, b):
        # input should be (d,b,h)
        pooled_output = pooled_output.view(d,b,self.hidden_size)
        # src_mask = self.transformer_model.generate_square_subsequent_mask(d).to(self.device)
        pooled_output = self.transformer_model(pooled_output, src_mask=None)
        pooled_output = pooled_output.view(b,d,self.hidden_size)
        return pooled_output
    
    def label_embed(self, y_caps, y_masks, rnn_out, d, b):
        last_hidden, clusters, hidden, att = self.bert(y_caps, attention_mask=y_masks)
        # clusters = self.mapping(clusters) # (n, 256)

        gram = torch.mm(clusters, clusters.permute(1,0)) # (n, n)
        rnn_out = rnn_out.reshape(b*d, self.hidden_size) # (b*d, 768)
        weights = torch.mm(rnn_out, clusters.permute(1,0)) # (b*d, n)
        logits = torch.mm(weights, torch.inverse(gram))
        logits = logits.view(b,d,self.num_labels)

        return logits
    
    def DiSAN(self, pooled_output, d, b):
        # input should be (b,h,d)
        pooled_score = pooled_output.view(b,self.hidden_size,d)
        pooled_score = torch.sigmoid(self.conv1(pooled_score))
        pooled_score = self.conv2(pooled_score)
        pooled_score = F.softmax(pooled_score, dim=-1)
        pooled_score = pooled_score.view(b,d,self.hidden_size)
        pooled_output = pooled_score * pooled_output
        return pooled_output


    def forward(self, result_ids, result_token_masks, result_masks, lengths, result_slot_labels, labels, 
                y_caps, y_masks, slot_caps, slot_masks, result_kg):
        """
        Inputs:
        result_ids:         (b, d, t)
        result_token_masks: (b, d, t)
        result_masks:       (b, d)
        lengths:            (b)
        result_slot_labels: (b, d, t)
        labels:             (b, d, l)

        BERT outputs:
        last_hidden_states: (bxd, t, h)
        pooled_output: (bxd, h), from output of a linear classifier + tanh
        hidden_states: 13 x (bxd, t, h), embed to last layer embedding
        attentions: 12 x (bxd, num_heads, t, t)
        """

        ############### 1. Token-level BERT encoding ###############
        b,d,t = result_ids.shape
        result_ids = result_ids.view(-1, t)
        result_token_masks = result_token_masks.view(-1, t)
        last_hidden_states, pooled_output, hidden_states, attentions = self.bert(result_ids, attention_mask=result_token_masks)
        pooled_output = pooled_output.view(b,d,self.hidden_size)

        ## Token: Self-attentive
        # pooled_output = self.self_attentive(last_hidden_states, d, b) # (b,d,l)
        # logits = self.classifier_bert(pooled_output)

        if self.opt.run_baseline == 'bert_naive':
            pooled_output_d = self.dropout(pooled_output)
            logits = self.classifier(pooled_output_d)

            # Remove padding
            logits_no_pad = []
            labels_no_pad = []
            for i in range(b):
                logits_no_pad.append(logits[i,:lengths[i],:])
                labels_no_pad.append(labels[i,:lengths[i],:])
            logits = torch.cat(logits_no_pad, dim=0)
            labels = torch.cat(labels_no_pad, dim=0)

            slot_hidden = last_hidden_states[:, 1:, :]
            slot_out = self.slot_classifier_naive(slot_hidden)
            slot_out = slot_out.view(-1, self.num_slot_labels)

            return logits, labels, slot_out
        elif self.opt.run_baseline == 'laban':
            pooled_output = pooled_output.view(b*d,self.hidden_size)
            last_hidden, clusters, hidden, att = self.bert(y_caps, attention_mask=y_masks)
            gram = torch.mm(clusters, clusters.permute(1,0)) # (n, n)
            weights = torch.mm(pooled_output, clusters.permute(1,0))
            weights = torch.mm(weights, torch.inverse(gram)) * np.sqrt(768)
            logits = weights.view(b, d, -1)

            # Remove padding
            logits_no_pad = []
            labels_no_pad = []
            for i in range(b):
                logits_no_pad.append(logits[i,:lengths[i],:])
                labels_no_pad.append(labels[i,:lengths[i],:])
            logits = torch.cat(logits_no_pad, dim=0)
            labels = torch.cat(labels_no_pad, dim=0)

            slot_hidden = last_hidden_states[:, 1:, :]
            slot_out = self.slot_classifier_naive(slot_hidden)
            slot_out = slot_out.view(-1, self.num_slot_labels)

            return logits, labels, slot_out

        
        ############### 2. Turn-level Context encoding ###############
        ## Turn: MHA
        # pooled_output = self.mha(pooled_output, d, b) # (b,d,l)

        ## Turn: DiSAN
        if self.opt.run_baseline == 'casa':
            context_vector = self.DiSAN(pooled_output, d, b)
            final_hidden = torch.cat([pooled_output, context_vector], dim=-1)
            final_hidden = self.fc1(final_hidden)
            logits = self.classifier_rnn(final_hidden)

            # Remove padding
            logits_no_pad = []
            labels_no_pad = []
            for i in range(b):
                logits_no_pad.append(logits[i,:lengths[i],:])
                labels_no_pad.append(labels[i,:lengths[i],:])
            logits = torch.cat(logits_no_pad, dim=0)
            labels = torch.cat(labels_no_pad, dim=0) 

            slot_vectors = last_hidden_states # (b*d,t,h)
            intent_context = final_hidden.unsqueeze(2).repeat(1,1,t,1).reshape(-1,t,self.rnn_hidden) # (b*d,t,hr)
            slot_inputs = torch.cat([slot_vectors, intent_context], dim=-1) # (b*d,t,h+hr)
            slot_out = self.baseline(slot_inputs)

            return logits, labels, slot_out

        ## Turn: CHAN
        pooled_output, ffscores = self.context_encoder(pooled_output, result_masks)
        # logits = self.classifier_bert(pooled_output) # (b,d,l)

        ## Turn: transformer
        # transformer_out, attention = self.transformer_encoder(pooled_output, pooled_output, pooled_output, result_masks)
        # transformer_out = self.dropout(transformer_out)
        # logits = self.classifier_transformer(transformer_out) # (b,d,l)

        if self.model_mode == 'kg_att2_context_as_hidden_lka':
            slot_vectors = last_hidden_states
            return self.kg_att2_context_as_hidden_lka(b, d, t, result_kg, slot_vectors, pooled_output, lengths, labels)

        if self.model_mode == 'kg_att2_context_global':
            slot_vectors = last_hidden_states
            return self.kg_att2_context_global(b, d, t, result_kg, slot_vectors, pooled_output, lengths, labels)
        
        if self.model_mode == 'kg_att2_context_local':
            slot_vectors = last_hidden_states
            return self.kg_att2_context_local(b, d, t, result_kg, slot_vectors, pooled_output, lengths, labels)
        
        if self.model_mode == 'glka_trans':
            slot_vectors = last_hidden_states
            return self.glka_trans(b, d, t, result_kg, slot_vectors, pooled_output, lengths, labels)

        ############### 3. Intent: RNN prediction ###############
        ## Prediction: RNN
        rnn_out, _ = self.rnn(pooled_output)
        rnn_out = self.dropout(rnn_out)
        logits = self.classifier_rnn(rnn_out) # (b,d,l)

        ## Prediction: Label Embedding
        # logits = self.label_embed(y_caps, y_masks, pooled_output, d, b)

        # Remove padding
        logits_no_pad = []
        labels_no_pad = []
        for i in range(b):
            logits_no_pad.append(logits[i,:lengths[i],:])
            labels_no_pad.append(labels[i,:lengths[i],:])
        logits = torch.cat(logits_no_pad, dim=0)
        labels = torch.cat(labels_no_pad, dim=0)   

        ############### 4. Slot: prediction ###############

        if self.model_mode == 'baseline':
            # baseline
            slot_vectors = last_hidden_states # (b*d,t,h)
            intent_context = rnn_out.unsqueeze(2).repeat(1,1,t,1).reshape(-1,t,self.rnn_hidden) # (b*d,t,hr)
            slot_inputs = torch.cat([slot_vectors, intent_context], dim=-1) # (b*d,t,h+hr)
            slot_out = self.baseline(slot_inputs)
        elif self.model_mode == 'baseline_attention':
            # baseline attention
            slot_vectors = last_hidden_states # (b*d,t,h)
            intent_context = rnn_out.unsqueeze(2).repeat(1,1,t,1).reshape(-1,t,self.rnn_hidden) # (b*d,t,hr)
            slot_inputs = torch.cat([slot_vectors, intent_context], dim=-1) # (b*d,t,h+hr)
            slot_out = self.baseline_attention(slot_inputs)
        elif self.model_mode == 'slot_embedding':
            slot_vectors = last_hidden_states # (b*d,t,h)
            intent_context = rnn_out.unsqueeze(2).repeat(1,1,t,1).reshape(-1,t,self.rnn_hidden) # (b*d,t,hr)
            slot_inputs = torch.cat([slot_vectors, intent_context], dim=-1) # (b*d,t,h+hr)
            slot_out = self.slot_embedding(b, d, t, slot_caps, slot_masks, slot_inputs)
        elif self.model_mode == 'kg_att_context_as_hidden':
            # attention 1 context as hidden
            slot_vectors = last_hidden_states # (b*d,t,h)
            slot_out = self.kg_att_context_as_hidden(b, d, t, result_kg, slot_vectors, rnn_out)
        elif self.model_mode == 'kg_att_context_as_input':
            # attention 1 context as input
            slot_vectors = last_hidden_states # (b*d,t,h)
            intent_context = rnn_out.unsqueeze(2).repeat(1,1,t,1).reshape(-1,t,self.rnn_hidden) # (b*d,t,hr)
            slot_inputs = torch.cat([slot_vectors, intent_context], dim=-1) # (b*d,t,h+hr)
            slot_out = self.kg_att_context_as_input(b, d, t, result_kg, slot_inputs)
        elif self.model_mode == 'kg_att2_context_as_hidden':
            # attention 2 context as hidden
            slot_vectors = last_hidden_states # (b*d,t,h)
            slot_out = self.kg_att2_context_as_hidden(b, d, t, result_kg, slot_vectors, rnn_out)
        elif self.model_mode == 'kg_att2_context_as_hidden_slot':
            # attention 2 context as hidden
            slot_vectors = last_hidden_states # (b*d,t,h)
            slot_out = self.kg_att2_context_as_hidden_slot(b, d, t, result_kg, slot_vectors, rnn_out, slot_caps, slot_masks)
        elif self.model_mode == 'kg_att2_context_as_hidden_gating':
            # attention 2 context as hidden
            slot_vectors = last_hidden_states # (b*d,t,h)
            slot_out, score_kg = self.kg_att2_context_as_hidden_gating(b, d, t, result_kg, slot_vectors, pooled_output)

        # slot_loss = -self.crf(slot_out, result_slot_labels)

        # slot_rnn_out = slot_rnn_out[:,-1,:].reshape(b,d,-1)
        # logits = self.classifier_rnn(slot_rnn_out)

        # logits_no_pad = []
        # labels_no_pad = []
        # for i in range(b):
        #     logits_no_pad.append(logits[i,:lengths[i],:])
        #     labels_no_pad.append(labels[i,:lengths[i],:])
        # logits = torch.cat(logits_no_pad, dim=0)
        # labels = torch.cat(labels_no_pad, dim=0)

        return logits, labels, slot_out#, score_kg
    
    
    
    ################################################################################
    ############################# function for testing #############################
    ################################################################################
    # baseline
    def define_baseline(self, num_slot_labels):
        self.slot_rnn = nn.LSTM(input_size=self.hidden_size+self.rnn_hidden,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)
    
    def baseline(self, slot_inputs):
        slot_rnn_out, _ = self.slot_rnn(slot_inputs)
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)
        return slot_out
    
    # baseline attention
    def define_baseline_attention(self, num_slot_labels):
        """Create the slot head for the 'baseline_attention' mode.

        Builds a MutualIterativeAttention encoder (constructed with
        ``self.device``) and a linear classifier over
        hidden_size + rnn_hidden features.
        """
        feature_dim = self.hidden_size + self.rnn_hidden
        self.mia_encoder = MutualIterativeAttention(self.device)
        self.slot_classifier = nn.Linear(feature_dim, num_slot_labels)
    
    def baseline_attention(self, slot_inputs):
        slot_out = self.mia_encoder(slot_inputs, slot_inputs) # (b*d*t, h+hr)
        slot_out = self.slot_classifier(slot_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)
        return slot_out
    
    # slot
    def define_slot_embedding(self, num_slot_labels):
        """Create the layers for the 'slot_embedding' mode.

        Loads a second pretrained BERT used to encode the slots, a
        single-layer bidirectional LSTM over hidden_size + rnn_hidden
        features, a projection of the BERT slot encodings to 2 * rnn_hidden,
        and a linear slot classifier.
        """
        token_dim = self.hidden_size + self.rnn_hidden
        self.bertslot = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
        self.slot_rnn = nn.LSTM(token_dim, self.rnn_hidden,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)
        # Projects slot encodings to the width of the bidirectional LSTM output.
        self.linear_slot = nn.Linear(self.hidden_size, 2 * self.rnn_hidden)
        self.slot_classifier = nn.Linear(token_dim, num_slot_labels)
    
    def slot_embedding(self, b, d, t, slot_caps, slot_masks, slot_inputs):
        last_hidden, slot_embeds, hidden, att = self.bertslot(slot_caps, attention_mask=slot_masks) # (k, h)
        out1 = self.linear_slot(slot_embeds) # (k, 2*hr)

        slot_rnn_out, _ = self.slot_rnn(slot_inputs)
        slot_rnn_out = slot_rnn_out.reshape(-1, self.rnn_hidden*2)
        slot_out = torch.mm(slot_rnn_out, out1.transpose(1,0)) # (b*d*t, num_slots)

        return slot_out
    
    # kg_att_context_as_hidden
    def define_kg_att_context_as_hidden(self, num_slot_labels):
        self.slot_rnn = nn.LSTM(input_size=2*self.hidden_size,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.k_linear1 = nn.Linear(300, self.hidden_size)
        self.att1 = nn.Linear(self.hidden_size, 200)
        self.att2 = nn.Linear(self.hidden_size, 200)
        self.score = nn.Linear(200, 1)
    
    def kg_att_context_as_hidden(self, b, d, t, kg_input, slot_input, context_input):
        
        # knowledge attention
        kg_input = kg_input.reshape(-1, 5, 300)
        v_kg = self.tanh(self.k_linear1(kg_input)) # (b*d*t, 5, h)

        slot_inputs = slot_input.reshape(-1, self.hidden_size)
        context = self.att1(slot_inputs).unsqueeze(1).repeat(1, 5, 1)
        knowledge = self.att2(v_kg)

        score_kg = self.score(self.tanh(context + knowledge)) # (b*d*t, 5, 1)
        score_kg = nn.Softmax(dim=1)(score_kg) # (b*d*t, 5, 1)
        fused_v_kg = torch.bmm(score_kg.permute(0,2,1), v_kg).squeeze(1) # (b*d*t, h)

        # rnn
        rnn_inputs = torch.cat([slot_inputs, fused_v_kg], dim=-1)
        rnn_inputs = rnn_inputs.reshape(b*d, t, -1)

        context_input = context_input.reshape(-1, self.rnn_hidden).unsqueeze(0).repeat(2,1,1)
        (h_0, c_0) = context_input, torch.zeros(*context_input.shape).to(self.device)

        slot_rnn_out, _ = self.slot_rnn(rnn_inputs, (h_0, c_0))
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)

        return slot_out
    
    # kg_att_context_as_input
    def define_kg_att_context_as_input(self, num_slot_labels):
        self.slot_rnn = nn.LSTM(input_size=2*(self.hidden_size+self.rnn_hidden),
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.k_linear1 = nn.Linear(300, self.hidden_size+self.rnn_hidden)
        self.att1 = nn.Linear(self.hidden_size+self.rnn_hidden, 200)
        self.att2 = nn.Linear(self.hidden_size+self.rnn_hidden, 200)
        self.score = nn.Linear(200, 1)
    
    def kg_att_context_as_input(self, b, d, t, kg_input, slot_input):

        # knowledge attention
        kg_input = kg_input.reshape(-1, 5, 300)
        v_kg = self.tanh(self.k_linear1(kg_input)) # (b*d*t, 5, h+hr)

        slot_inputs = slot_input.reshape(-1, self.hidden_size+self.rnn_hidden)
        context = self.att1(slot_inputs).unsqueeze(1).repeat(1, 5, 1)
        knowledge = self.att2(v_kg)

        score_kg = self.score(self.tanh(context + knowledge)) # (b*d*t, 5, 1)
        score_kg = nn.Softmax(dim=1)(score_kg) # (b*d*t, 5, 1)
        fused_v_kg = torch.bmm(score_kg.permute(0,2,1), v_kg).squeeze(1) # (b*d*t, h+hr)

        # rnn
        rnn_inputs = torch.cat([slot_inputs, fused_v_kg], dim=-1)
        rnn_inputs = rnn_inputs.reshape(b*d, t, -1)

        slot_rnn_out, _ = self.slot_rnn(rnn_inputs)
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)

        return slot_out
    
    # kg_att2_context_as_hidden
    def define_kg_att2_context_as_hidden(self, num_slot_labels):
        self.slot_rnn = nn.LSTM(input_size=self.hidden_size+200,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.rel_linear = nn.Linear(100, 200)
        self.tail_linear = nn.Linear(100, 200)
        self.context_linear = nn.Linear(self.hidden_size, 200)
    
    def kg_att2_context_as_hidden(self, b, d, t, kg_input, slot_input, context_input):
        kg_input = kg_input.reshape(-1, 5, 300)
        rel = kg_input[:,:,100:200]
        tail = kg_input[:,:,200:]

        # knowledge attention
        slot_inputs = slot_input.reshape(-1, self.hidden_size)
        context = self.context_linear(slot_inputs).unsqueeze(1)
        out1 = self.tanh(self.rel_linear(rel)+self.tail_linear(tail))
        score_kg = torch.bmm(out1, context.permute(0,2,1)) # (b*d*t, 5, 1)
        score_kg = nn.Softmax(dim=1)(score_kg) # (b*d*t, 5, 1)
        fused_v_kg = torch.bmm(score_kg.permute(0,2,1), torch.cat((rel, tail), dim=2)).squeeze(1) # (b*d*t, h)

        # rnn
        rnn_inputs = torch.cat([slot_inputs, fused_v_kg], dim=-1)
        rnn_inputs = rnn_inputs.reshape(b*d, t, -1)

        context_input = context_input.reshape(-1, self.rnn_hidden).unsqueeze(0).repeat(2,1,1)
        (h_0, c_0) = context_input, torch.zeros(*context_input.shape).to(self.device)

        slot_rnn_out, _ = self.slot_rnn(rnn_inputs, (h_0, c_0))
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)

        return slot_out
    
    # kg_att2_context_as_hidden_slot
    def define_kg_att2_context_as_hidden_slot(self, num_slot_labels):
        """Create the layers for the 'kg_att2_context_as_hidden_slot' mode.

        Combines a second pretrained BERT (with a projection of its slot
        encodings to 2 * rnn_hidden) with the relation/tail knowledge
        attention layers and a bidirectional slot LSTM.
        """
        kg_dim = 200

        # Slot encoder and its projection to the bidirectional LSTM output width.
        self.bertslot = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
        self.linear_slot = nn.Linear(self.hidden_size, 2 * self.rnn_hidden)

        self.slot_rnn = nn.LSTM(self.hidden_size + kg_dim, self.rnn_hidden,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)
        self.slot_classifier = nn.Linear(2 * self.rnn_hidden, num_slot_labels)

        # Attention projections.
        self.rel_linear = nn.Linear(100, kg_dim)
        self.tail_linear = nn.Linear(100, kg_dim)
        self.context_linear = nn.Linear(self.hidden_size, kg_dim)
    
    def kg_att2_context_as_hidden_slot(self, b, d, t, kg_input, slot_input, context_input, slot_caps, slot_masks):
        kg_input = kg_input.reshape(-1, 5, 300)
        rel = kg_input[:,:,100:200]
        tail = kg_input[:,:,200:]

        # knowledge attention
        slot_inputs = slot_input.reshape(-1, self.hidden_size)
        context = self.context_linear(slot_inputs).unsqueeze(1)
        out1 = self.tanh(self.rel_linear(rel)+self.tail_linear(tail))
        score_kg = torch.bmm(out1, context.permute(0,2,1)) # (b*d*t, 5, 1)
        score_kg = nn.Softmax(dim=1)(score_kg) # (b*d*t, 5, 1)
        fused_v_kg = torch.bmm(score_kg.permute(0,2,1), torch.cat((rel, tail), dim=2)).squeeze(1) # (b*d*t, h)

        # slot embeddings
        last_hidden, slot_embeds, hidden, att = self.bertslot(slot_caps, attention_mask=slot_masks) # (k, h)
        slot_embeddings = self.linear_slot(slot_embeds) # (k, 2*hr)

        # rnn
        rnn_inputs = torch.cat([slot_inputs, fused_v_kg], dim=-1)
        rnn_inputs = rnn_inputs.reshape(b*d, t, -1)

        context_input = context_input.reshape(-1, self.rnn_hidden).unsqueeze(0).repeat(2,1,1)
        (h_0, c_0) = context_input, torch.zeros(*context_input.shape).to(self.device)

        slot_rnn_out, _ = self.slot_rnn(rnn_inputs, (h_0, c_0))
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_rnn_out = slot_rnn_out.reshape(-1, self.rnn_hidden*2)
        slot_out = torch.mm(slot_rnn_out, slot_embeddings.transpose(1,0)) # (b*d*t, num_slots)

        return slot_out
    
    # kg_att2_context_as_hidden_gating, KABEM
    def define_kg_att2_context_as_hidden_gating(self, num_slot_labels):
        self.kg_hidden_size = 200
        self.slot_rnn = nn.LSTM(input_size=200,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.rel_linear = nn.Linear(100, self.kg_hidden_size)
        self.tail_linear = nn.Linear(100, self.kg_hidden_size)
        self.context_linear = nn.Linear(2*self.hidden_size, self.kg_hidden_size)

        self.gating_linear = nn.Linear(self.kg_hidden_size*2, 1)
        self.context_dnn = nn.Linear(self.hidden_size, self.rnn_hidden)
    
    def kg_att2_context_as_hidden_gating(self, b, d, t, kg_input, slot_input, context_input):
        kg_input = kg_input.reshape(-1, 5, 300)
        rel = kg_input[:,:,100:200]
        tail = kg_input[:,:,200:]

        # knowledge attention
        # transform to the same space
        slot_inputs = slot_input.reshape(-1, self.hidden_size)
        context_x = context_input.reshape(-1, self.hidden_size).unsqueeze(1).repeat(1,t,1).reshape(-1, self.hidden_size)
        fused = torch.cat([context_x, slot_inputs], dim=-1)
        slot_inputs_transform = self.context_linear(fused).unsqueeze(1)
        rel_transform = self.rel_linear(rel)
        tail_transform = self.tail_linear(tail)

        out1 = self.tanh(rel_transform+tail_transform)
        score_kg = torch.bmm(out1, slot_inputs_transform.permute(0,2,1)) # (b*d*t, 5, 1)
        score_kg = nn.Softmax(dim=1)(score_kg) # (b*d*t, 5, 1)
        fused_v_kg = torch.bmm(score_kg.permute(0,2,1), rel_transform+tail_transform).squeeze(1) # (b*d*t, 300)

        # gating
        gate_inputs = torch.cat([slot_inputs_transform.squeeze(1), fused_v_kg], dim=-1) # (b*d*t, 300*2)
        g = nn.Sigmoid()(self.gating_linear(gate_inputs))
        rnn_inputs = g * slot_inputs_transform.squeeze(1) + (1-g) * fused_v_kg

        # rnn
        # rnn_inputs = torch.cat([slot_inputs, fused_v_kg], dim=-1)
        rnn_inputs = rnn_inputs.reshape(b*d, t, -1)

        context_input = self.context_dnn(context_input)
        context_input = context_input.reshape(-1, self.rnn_hidden).unsqueeze(0).repeat(2,1,1)
        (h_0, c_0) = context_input, torch.zeros(*context_input.shape).to(self.device)

        slot_rnn_out, _ = self.slot_rnn(rnn_inputs, (h_0, c_0))
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.view(-1, self.num_slot_labels) # (b*d*t, num_slots)

        return slot_out, score_kg.reshape(b,d,t,-1)
    
    # kg_att2_context_as_hidden_lka, lka
    def define_kg_att2_context_as_hidden_lka(self, num_slot_labels):
        self.kg_hidden_size = 200
        self.intent_rnn = nn.LSTM(input_size=self.kg_hidden_size*2,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.slot_rnn = nn.LSTM(input_size=self.hidden_size+200,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.intent_classifier = nn.Linear(2*self.rnn_hidden, self.num_labels)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.rel_linear = nn.Linear(100, self.kg_hidden_size)
        self.tail_linear = nn.Linear(100, self.kg_hidden_size)
        self.context_linear = nn.Linear(2*self.hidden_size, self.kg_hidden_size)

        self.gating_linear = nn.Linear(self.kg_hidden_size*2, 1)
        self.context_dnn = nn.Linear(self.hidden_size, self.rnn_hidden)
    
    def kg_att2_context_as_hidden_lka(self, b, d, t, kg_input, slot_input, context_input, lengths, labels):
        """Token-level knowledge attention with a context-initialised slot LSTM.

        Each token attends over its own 5 knowledge triples, using a query
        built from the turn context concatenated with the token
        representation.  The attended knowledge is averaged over tokens for
        intent prediction and kept per token for slot prediction; the slot
        LSTM starts from the projected turn context as its hidden state.

        Args:
            b, d, t: sizes of the batch, turn and token axes used by the
                reshapes below.
            kg_input: knowledge features, reshaped to (b*d*t, 5, 300).
                Columns 100:200 are read as the relation embedding and
                200:300 as the tail embedding; columns 0:100 are not used
                here (presumably the head entity -- confirm with the caller).
            slot_input: token representations, reshaped to (b*d*t, hidden_size).
            context_input: turn representations, reshaped to (b*d, hidden_size).
            lengths: lengths[i] is the number of leading turns kept for
                batch item i when padding is stripped.
            labels: intent labels, indexed as labels[i, :lengths[i], :].

        Returns:
            Tuple ``(logits, labels, slot_out)``: intent logits and labels
            with padded turns removed (concatenated over the batch), and slot
            logits of shape (b*d*t, num_slot_labels).
        """
        kg_input = kg_input.reshape(-1, 5, 300)
        rel = kg_input[:, :, 100:200]
        tail = kg_input[:, :, 200:]

        # knowledge attention
        # transform the query and the triples to the same space
        slot_inputs = slot_input.reshape(-1, self.hidden_size)  # (b*d*t, h)
        context_x = (context_input.reshape(-1, self.hidden_size)
                     .unsqueeze(1).repeat(1, t, 1)
                     .reshape(-1, self.hidden_size))  # (b*d*t, h)
        fused = torch.cat([context_x, slot_inputs], dim=-1)
        slot_inputs_transform = self.context_linear(fused).unsqueeze(1)  # (b*d*t, 1, kg_hidden)
        # Projected triples are used twice (keys after tanh, values as-is),
        # so compute the sum once.
        kg_values = self.rel_linear(rel) + self.tail_linear(tail)  # (b*d*t, 5, kg_hidden)

        score_kg = torch.bmm(self.tanh(kg_values), slot_inputs_transform.permute(0, 2, 1))  # (b*d*t, 5, 1)
        score_kg = torch.softmax(score_kg, dim=1)  # normalise over the 5 triples
        fused_v_kg = torch.bmm(score_kg.permute(0, 2, 1), kg_values).squeeze(1)  # (b*d*t, kg_hidden)

        # intent prediction
        slot_inputs_transform = slot_inputs_transform.reshape(-1, self.kg_hidden_size)  # (b*d*t, kg_hidden)
        fused_intent = torch.cat([slot_inputs_transform, fused_v_kg], dim=1)
        fused_intent = fused_intent.reshape(b*d, t, -1).mean(dim=1).reshape(b, d, -1)
        rnn_out, _ = self.intent_rnn(fused_intent)
        rnn_out = self.dropout(rnn_out)
        logits = self.intent_classifier(rnn_out)  # (b, d, l)

        # Remove padding
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        # slot prediction
        fused_v_kg = fused_v_kg.reshape(b*d, t, -1)  # (b*d, t, kg_hidden)
        slot_input = slot_input.reshape(b*d, t, -1)  # (b*d, t, h)
        fused_slot = torch.cat([fused_v_kg, slot_input], dim=2)

        # The projected turn context initialises both LSTM directions.
        h_0 = self.context_dnn(context_input)
        h_0 = h_0.reshape(-1, self.rnn_hidden).unsqueeze(0).repeat(2, 1, 1)  # (2, b*d, rnn_hidden)
        # zeros_like keeps c_0 on the same device and dtype as h_0, instead of
        # assuming float32 on self.device.
        c_0 = torch.zeros_like(h_0)

        slot_rnn_out, _ = self.slot_rnn(fused_slot, (h_0, c_0))
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.reshape(-1, self.num_slot_labels)  # (b*d*t, num_slots)

        return logits, labels, slot_out
    
    # kg_att2_context_global, gka
    def define_kg_att2_context_global(self, num_slot_labels):
        """Create the layers used by ``kg_att2_context_global``.

        Sets ``kg_hidden_size`` to 200 and builds the intent/slot BiLSTMs,
        their classifiers and the projections that map relation embeddings,
        tail embeddings (100-d each) and the turn context into the shared
        knowledge space.

        Args:
            num_slot_labels: number of output classes of the slot classifier.
        """
        self.kg_hidden_size = 200
        # Intent input: [projected context ; attended knowledge].
        self.intent_rnn = nn.LSTM(input_size=self.kg_hidden_size*2,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        # Slot input: [projected context ; attended knowledge ; token vector].
        # The token vector is hidden_size wide, so use that instead of a
        # literal 768 (identical for BERT-base).
        self.slot_rnn = nn.LSTM(input_size=self.kg_hidden_size*2+self.hidden_size,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        # Bidirectional LSTMs emit 2*rnn_hidden features per step.
        self.intent_classifier = nn.Linear(2*self.rnn_hidden, self.num_labels)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.rel_linear = nn.Linear(100, self.kg_hidden_size)
        self.tail_linear = nn.Linear(100, self.kg_hidden_size)
        self.context_linear = nn.Linear(self.hidden_size, self.kg_hidden_size)

    def kg_att2_context_global(self, b, d, t, kg_input, slot_input, context_input, lengths, labels):
        """Turn-level (global) knowledge attention for intent and slot prediction.

        A single query per turn, derived from the turn context, attends over
        all ``t*5`` knowledge triples of that turn.  The resulting knowledge
        summary is combined with the projected context for intent prediction
        and broadcast to every token for slot prediction.

        Args:
            b, d, t: sizes of the batch, turn and token axes used by the
                reshapes below.
            kg_input: knowledge features, reshaped to (b*d, t*5, 300).  A
                triple whose first feature equals 0 is treated as padding;
                columns 100:200 / 200:300 are the relation / tail embeddings.
            slot_input: token representations, concatenated as-is along
                dim 2, so it must already be (b*d, t, hidden).
            context_input: turn representations, reshaped to (b*d, hidden_size).
            lengths: lengths[i] is the number of leading turns kept for
                batch item i.
            labels: intent labels, indexed as labels[i, :lengths[i], :].

        Returns:
            Tuple ``(logits, labels, slot_out)``: intent logits and labels
            with padded turns removed, and slot logits of shape
            (b*d*t, num_slot_labels).
        """
        triples = kg_input.reshape(b*d, t*5, 300)
        is_padding = triples[:, :, 0] == 0.  # (b*d, t*5)

        # Project the turn context and the triples into the shared space.
        turn_query = self.context_linear(context_input.reshape(-1, self.hidden_size))  # (b*d, kg_hidden)
        rel_proj = self.rel_linear(triples[:, :, 100:200])
        tail_proj = self.tail_linear(triples[:, :, 200:])

        # One attention distribution per turn over its t*5 triples.
        attn = torch.bmm(self.tanh(rel_proj + tail_proj), turn_query.unsqueeze(2))  # (b*d, t*5, 1)
        attn = attn.masked_fill(is_padding.unsqueeze(-1), -1e9)  # padded triples get ~0 weight
        attn = torch.softmax(attn, dim=1)
        kg_summary = torch.bmm(attn.transpose(1, 2), rel_proj + tail_proj).squeeze(1)  # (b*d, kg_hidden)

        # Intent branch: [query ; knowledge] per turn, run over the dialogue.
        intent_features = torch.cat([turn_query, kg_summary], dim=1).reshape(b, d, -1)
        intent_hidden, _ = self.intent_rnn(intent_features)
        logits = self.intent_classifier(self.dropout(intent_hidden))  # (b, d, l)

        # Strip padded turns and flatten over the batch.
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        # Slot branch: every token sees the same turn-level query and knowledge.
        query_per_token = turn_query.reshape(b*d, 1, -1).expand(-1, t, -1)
        kg_per_token = kg_summary.unsqueeze(1).expand(-1, t, -1)
        slot_features = torch.cat([query_per_token, kg_per_token, slot_input], dim=2)

        slot_hidden, _ = self.slot_rnn(slot_features)
        slot_out = self.slot_classifier(self.dropout(slot_hidden))
        slot_out = slot_out.view(-1, self.num_slot_labels)  # (b*d*t, num_slots)

        return logits, labels, slot_out
    
    # kg_att2_context_local, glka
    def define_kg_att2_context_local(self, num_slot_labels):
        """Create the layers used by ``kg_att2_context_local``.

        Sets ``kg_hidden_size`` to 200 and builds the intent/slot BiLSTMs,
        their classifiers and the projections that map relation embeddings,
        tail embeddings (100-d each) and the [context ; token] query into the
        shared knowledge space.

        Args:
            num_slot_labels: number of output classes of the slot classifier.
        """
        self.kg_hidden_size = 200
        # Intent input: [projected query ; attended knowledge], token-averaged.
        self.intent_rnn = nn.LSTM(input_size=self.kg_hidden_size*2,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        # Slot input: [attended knowledge ; token vector].  The token vector is
        # hidden_size wide, so use that instead of a literal 768 (identical
        # for BERT-base).
        self.slot_rnn = nn.LSTM(input_size=self.kg_hidden_size+self.hidden_size,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        # Bidirectional LSTMs emit 2*rnn_hidden features per step.
        self.intent_classifier = nn.Linear(2*self.rnn_hidden, self.num_labels)
        self.slot_classifier = nn.Linear(2*self.rnn_hidden, num_slot_labels)

        self.rel_linear = nn.Linear(100, self.kg_hidden_size)
        self.tail_linear = nn.Linear(100, self.kg_hidden_size)
        # The query is the turn context concatenated with the token vector.
        self.context_linear = nn.Linear(self.hidden_size*2, self.kg_hidden_size)

    def kg_att2_context_local(self, b, d, t, kg_input, slot_input, context_input, lengths, labels):
        """Token-level attention over all knowledge triples of the turn.

        Every token builds a query from [turn context ; token vector] and
        attends over the ``t*5`` triples of its turn.  The attended knowledge
        is averaged over tokens for intent prediction and kept per token for
        slot prediction.

        The attention is computed as one batched product of shape
        (b*d, t, t*5) rather than by copying the triples once per token,
        which yields the same values with t-times less memory and compute.

        Args:
            b, d, t: sizes of the batch, turn and token axes used by the
                reshapes below.
            kg_input: knowledge features, reshaped to (b*d, t*5, 300).  A
                triple whose first feature equals 0 is treated as padding;
                columns 100:200 / 200:300 are the relation / tail embeddings.
            slot_input: token representations, reshaped to (b*d, t, hidden_size).
            context_input: turn representations, reshaped to (b*d, hidden_size).
            lengths: lengths[i] is the number of leading turns kept for
                batch item i.
            labels: intent labels, indexed as labels[i, :lengths[i], :].

        Returns:
            Tuple ``(logits, labels, slot_out)``: intent logits and labels
            with padded turns removed, and slot logits of shape
            (b*d*t, num_slot_labels).
        """
        # kg_input: (b*d, t*5, 300)
        # slot_input: (b*d, t, hidden)
        # context_input: (b, d, hidden)
        kg_input = kg_input.reshape(b*d, t*5, 300)
        # Broadcast the padding mask over the token (query) axis.
        pad_mask = (kg_input[:, :, 0] == 0.).unsqueeze(1)  # (b*d, 1, t*5)

        # Project the triples once per turn; all tokens of the turn share them.
        kg_values = self.rel_linear(kg_input[:, :, 100:200]) + self.tail_linear(kg_input[:, :, 200:])  # (b*d, t*5, kg_hidden)
        kg_keys = self.tanh(kg_values)

        # knowledge attention
        # transform the queries to the same space as the triples
        slot_input = slot_input.reshape(b*d, t, self.hidden_size)  # (b*d, t, h)
        context_x = context_input.reshape(-1, 1, self.hidden_size).expand(-1, t, -1)  # (b*d, t, h)
        queries = self.context_linear(torch.cat([context_x, slot_input], dim=-1))  # (b*d, t, kg_hidden)

        score_kg = torch.bmm(queries, kg_keys.transpose(1, 2))  # (b*d, t, t*5)
        score_kg = score_kg.masked_fill(pad_mask, -1e9)  # padded triples get ~0 weight
        score_kg = torch.softmax(score_kg, dim=-1)  # normalise over the triples
        fused_v_kg = torch.bmm(score_kg, kg_values)  # (b*d, t, kg_hidden)

        # intent prediction: average [query ; knowledge] over the tokens of a turn
        fused_intent = torch.cat([queries, fused_v_kg], dim=-1).mean(dim=1).reshape(b, d, -1)
        rnn_out, _ = self.intent_rnn(fused_intent)
        rnn_out = self.dropout(rnn_out)
        logits = self.intent_classifier(rnn_out)  # (b, d, l)

        # Remove padding
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        # slot prediction
        fused_slot = torch.cat([fused_v_kg, slot_input], dim=2)  # (b*d, t, kg_hidden + h)

        slot_rnn_out, _ = self.slot_rnn(fused_slot)
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        slot_out = slot_out.reshape(-1, self.num_slot_labels)  # (b*d*t, num_slots)

        # The attention weights are available as score_kg.reshape(b, d, t, -1).
        return logits, labels, slot_out
    
    # glka_trans
    def define_glka_trans(self, num_slot_labels):
        """Create the layers used by ``glka_trans``.

        Sets ``kg_hidden_size`` to 200 and builds a single ``ContextAttention``
        decoder shared by the intent and slot branches, the token projection
        ``slot_hidden``, both classifiers, and the projections that map
        relation embeddings, tail embeddings (100-d each) and the
        [context ; token] query into the shared knowledge space.

        Args:
            num_slot_labels: number of output classes of the slot classifier.
        """
        self.kg_hidden_size = 200
        # Both decoder inputs are 2*kg_hidden_size wide: [query ; knowledge]
        # for intents and [knowledge ; projected token] for slots.
        decoder_dim = 2 * self.kg_hidden_size
        self.decoder = ContextAttention(self.device, dim=decoder_dim, layer=2)
        # Token vectors are hidden_size wide; derive the widths from the
        # configured sizes instead of the literals 768 / 200 / 400 (identical
        # values for BERT-base).
        self.slot_hidden = nn.Linear(self.hidden_size, self.kg_hidden_size)
        self.intent_classifier = nn.Linear(decoder_dim, self.num_labels)
        self.slot_classifier = nn.Linear(decoder_dim, num_slot_labels)

        self.rel_linear = nn.Linear(100, self.kg_hidden_size)
        self.tail_linear = nn.Linear(100, self.kg_hidden_size)
        # The query is the turn context concatenated with the token vector.
        self.context_linear = nn.Linear(self.hidden_size*2, self.kg_hidden_size)

    def glka_trans(self, b, d, t, kg_input, slot_input, context_input, lengths, labels):
        """Token-level knowledge attention followed by the shared attention decoder.

        Every token builds a query from [turn context ; token vector] and
        attends over the ``t*5`` triples of its turn.  ``self.decoder`` is
        then applied over the turns of a dialogue (intent branch) and over the
        tokens of a turn (slot branch).

        The knowledge attention is computed as one batched product of shape
        (b*d, t, t*5) rather than by copying the triples once per token,
        which yields the same values with t-times less memory and compute.

        Args:
            b, d, t: sizes of the batch, turn and token axes used by the
                reshapes below.
            kg_input: knowledge features, reshaped to (b*d, t*5, 300).  A
                triple whose first feature equals 0 is treated as padding;
                columns 100:200 / 200:300 are the relation / tail embeddings.
            slot_input: token representations, reshaped to (b*d, t, hidden_size).
            context_input: turn representations, reshaped to (b*d, hidden_size).
            lengths: lengths[i] is the number of leading turns kept for
                batch item i.
            labels: intent labels, indexed as labels[i, :lengths[i], :].

        Returns:
            Tuple ``(logits, labels, slot_out)``: intent logits and labels
            with padded turns removed, and slot logits of shape
            (b*d*t, num_slot_labels).
        """
        # kg_input: (b*d, t*5, 300)
        # slot_input: (b*d, t, hidden)
        # context_input: (b, d, hidden)
        kg_input = kg_input.reshape(b*d, t*5, 300)
        # Broadcast the padding mask over the token (query) axis.
        pad_mask = (kg_input[:, :, 0] == 0.).unsqueeze(1)  # (b*d, 1, t*5)

        # Project the triples once per turn; all tokens of the turn share them.
        kg_values = self.rel_linear(kg_input[:, :, 100:200]) + self.tail_linear(kg_input[:, :, 200:])  # (b*d, t*5, kg_hidden)
        kg_keys = self.tanh(kg_values)

        # knowledge attention
        # transform the queries to the same space as the triples
        slot_input = slot_input.reshape(b*d, t, self.hidden_size)  # (b*d, t, h)
        context_x = context_input.reshape(-1, 1, self.hidden_size).expand(-1, t, -1)  # (b*d, t, h)
        queries = self.context_linear(torch.cat([context_x, slot_input], dim=-1))  # (b*d, t, kg_hidden)

        score_kg = torch.bmm(queries, kg_keys.transpose(1, 2))  # (b*d, t, t*5)
        score_kg = score_kg.masked_fill(pad_mask, -1e9)  # padded triples get ~0 weight
        score_kg = torch.softmax(score_kg, dim=-1)  # normalise over the triples
        fused_v_kg = torch.bmm(score_kg, kg_values)  # (b*d, t, kg_hidden)

        # intent prediction: average [query ; knowledge] over the tokens of a turn
        fused_intent = torch.cat([queries, fused_v_kg], dim=-1).mean(dim=1).reshape(b, d, -1)
        # NOTE(review): the mask passed to the decoder is all ones, so padded
        # turns are not masked out here -- confirm this is intended.
        intent_mask = torch.ones(b, d, device=self.device)
        intent_output, _ = self.decoder(fused_intent, intent_mask)
        intent_output = self.dropout(intent_output)
        logits = self.intent_classifier(intent_output)  # (b, d, l)

        # Remove padding
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        # slot prediction: [knowledge ; projected token] per token
        fused_slot = torch.cat([fused_v_kg, self.slot_hidden(slot_input)], dim=2)

        slot_mask = torch.ones(b*d, t, device=self.device)
        slot_output, _ = self.decoder(fused_slot, slot_mask)
        slot_output = self.dropout(slot_output)
        slot_out = self.slot_classifier(slot_output)
        slot_out = slot_out.reshape(-1, self.num_slot_labels)  # (b*d*t, num_slots)

        # The attention weights are available as score_kg.reshape(b, d, t, -1).
        return logits, labels, slot_out











