import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from transformers import BertTokenizer, BertModel

class KASLUM(nn.Module):
    """Joint intent classifier and slot tagger with knowledge-graph attention.

    Pipeline implemented by :meth:`forward`:

    1. Token ids are embedded and run through a bidirectional LSTM
       (``utterance_encoder``).
    2. Every token state attends over its knowledge-graph candidates
       (``context_linear`` / ``rel_linear`` / ``tail_linear``).
    3. Token states concatenated with the attended knowledge vector go through
       a second bidirectional LSTM (``final_rnn``) that feeds the slot
       classifier; its last time step feeds the intent classifier.

    A dialogue-history branch (``w1``, ``u1``, ``utterance_encoder2``) is also
    evaluated, but its output is not used by either classifier.
    """

    # Layout of the last axis of ``kg_input``: the first KG_HEAD_DIM features
    # are skipped (presumably the head-entity embedding -- TODO confirm),
    # followed by the relation features and then the tail features.
    KG_HEAD_DIM = 100
    KG_REL_DIM = 100
    KG_TAIL_DIM = 100

    def __init__(self, opt, num_labels=2, num_slot_labels=10):
        """Build all sub-modules.

        Args:
            opt: Stored as ``self.opt``; not otherwise read inside this class.
            num_labels: Output size of the intent classifier.
            num_slot_labels: Output size of the slot classifier.
        """
        super().__init__()
        # Stored for callers; nothing in this class moves tensors to it.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Only the tokenizer's vocabulary size is used here (embedding rows).
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        self.kg_hidden_size = 300

        emb_dim = 256
        hidden = 256
        enc_dim = 2 * hidden  # bidirectional LSTM output width

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and seeded parameter initialisation do not change.
        self.embedding = nn.Embedding(len(self.tokenizer.vocab), emb_dim)

        self.utterance_encoder = nn.LSTM(input_size=emb_dim,
                                         hidden_size=hidden,
                                         bidirectional=True,
                                         batch_first=True,
                                         num_layers=1)
        self.w1 = nn.Linear(enc_dim, 1)
        self.u1 = nn.Linear(enc_dim, 1)
        self.utterance_encoder2 = nn.LSTM(input_size=enc_dim,
                                          hidden_size=hidden,
                                          bidirectional=True,
                                          batch_first=True,
                                          num_layers=1)

        self.context_linear = nn.Linear(enc_dim, self.kg_hidden_size)
        self.rel_linear = nn.Linear(self.KG_REL_DIM, self.kg_hidden_size)
        self.tail_linear = nn.Linear(self.KG_TAIL_DIM, self.kg_hidden_size)

        self.final_rnn = nn.LSTM(input_size=enc_dim + self.kg_hidden_size,
                                 hidden_size=hidden,
                                 bidirectional=True,
                                 batch_first=True,
                                 num_layers=1)

        self.num_labels = num_labels
        self.num_slot_labels = num_slot_labels
        self.classifier_intent = nn.Linear(enc_dim, num_labels)
        self.classifier_slot = nn.Linear(enc_dim, num_slot_labels)

        self.dropout = nn.Dropout(0.1)

        self.opt = opt

    def _encode_history(self, utterance_vecs):
        """Summarise, for every turn, the turns that precede it.

        Args:
            utterance_vecs: ``(b, d, enc_dim)`` tensor, one vector per turn.

        Returns:
            ``(b, d, 2 * hidden)`` output of ``utterance_encoder2`` run over
            the per-turn history summaries.
        """
        num_turns = utterance_vecs.shape[1]
        # Both projections are per-turn, so compute them once instead of
        # re-projecting (and repeating) inside the loop.
        history_scores = self.w1(utterance_vecs)  # (b, d, 1)
        current_scores = self.u1(utterance_vecs)  # (b, d, 1)

        # The first turn has no history; it stands in for its own summary.
        summaries = [utterance_vecs[:, 0, :]]
        for i in range(1, num_turns):
            # (b, i, 1) + (b, 1, 1): score each earlier turn against turn i.
            weights = F.softmax(history_scores[:, :i] + current_scores[:, i:i + 1], dim=1)
            attended = torch.bmm(weights.transpose(2, 1), utterance_vecs[:, :i])
            summaries.append(attended.squeeze(1))
        summaries = torch.stack(summaries, dim=1)

        encoded, _ = self.utterance_encoder2(summaries)
        return encoded

    def _attend_knowledge(self, token_states, kg_input):
        """Attend over each token's knowledge-graph candidates.

        Args:
            token_states: ``(n, enc_dim)`` tensor, one row per token.
            kg_input: Tensor reshapeable to ``(n, k, KG width)``. ``k`` (the
                number of candidates per token) is inferred from the element
                count, so it does not have to be 5.

        Returns:
            ``(n, kg_hidden_size)`` attention-weighted sum of the candidates.
        """
        rel_start = self.KG_HEAD_DIM
        tail_start = rel_start + self.KG_REL_DIM
        kg_width = tail_start + self.KG_TAIL_DIM

        kg = kg_input.reshape(token_states.shape[0], -1, kg_width)
        rel = kg[:, :, rel_start:tail_start]
        tail = kg[:, :, tail_start:]

        query = self.context_linear(token_states).unsqueeze(2)  # (n, kg_hidden, 1)
        candidates = self.rel_linear(rel) + self.tail_linear(tail)  # (n, k, kg_hidden)

        scores = torch.bmm(torch.tanh(candidates), query)  # (n, k, 1)
        weights = F.softmax(scores, dim=1)  # normalise over the k candidates
        return torch.bmm(weights.transpose(1, 2), candidates).squeeze(1)

    def forward(self, result_ids, result_token_masks, result_masks,
                lengths, result_slot_labels, labels, y_caps, y_masks, s_caps, s_masks, kg_input):
        """Compute intent and slot logits for a batch.

        Args:
            result_ids: Integer tensor of shape ``(b, d, t)``; it is unpacked
                into exactly three dimensions and used as embedding indices.
                It may be non-contiguous.
            lengths: Indexable with ``0..b-1``; ``lengths[i]`` is how many
                leading entries along the second axis are kept for item ``i``.
            labels: Tensor indexed as ``labels[i, :lengths[i], :]``.
            kg_input: Knowledge features, see :meth:`_attend_knowledge`.
            result_token_masks, result_masks, result_slot_labels, y_caps,
            y_masks, s_caps, s_masks: Accepted for interface compatibility;
                not read by this method.

        Returns:
            Tuple ``(logits, labels, slot_logits)``:
            ``logits`` is ``(sum(lengths), num_labels)``, ``labels`` is the
            input labels with the same rows kept, and ``slot_logits`` is
            ``(b * d * t, num_slot_labels)``.
        """
        b, d, t = result_ids.shape

        # Utterance encoder. reshape (not view) so that non-contiguous id
        # tensors are accepted instead of raising a RuntimeError.
        token_ids = result_ids.reshape(b * d, t)
        embedded = self.embedding(token_ids)  # (b*d, t, emb_dim)
        rnn_out, _ = self.utterance_encoder(embedded)  # (b*d, t, enc_dim)

        # The state at the last time step represents each utterance.
        utterance_vecs = rnn_out[:, -1, :].reshape(b, d, -1)

        # NOTE(review): the history encoding is computed but nothing below
        # consumes it. It is kept so the set of executed modules is unchanged;
        # confirm whether it was meant to feed the intent classifier.
        self._encode_history(utterance_vecs)

        # Knowledge attention, one query per token.
        token_states = rnn_out.reshape(b * d * t, -1)
        fused_kg = self._attend_knowledge(token_states, kg_input)

        # Slot prediction.
        fused_input = torch.cat([token_states, fused_kg], dim=-1).reshape(b * d, t, -1)
        slot_outputs, _ = self.final_rnn(fused_input)
        slot_outputs = self.dropout(slot_outputs)
        slot_logits = self.classifier_slot(slot_outputs).reshape(-1, self.num_slot_labels)

        # Intent prediction from the last time step of the (dropped-out) slot states.
        intent_outputs = slot_outputs[:, -1, :].reshape(b, d, -1)
        logits = self.classifier_intent(intent_outputs)

        # Remove padding along the second axis, then flatten across the batch.
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        return logits, labels, slot_logits



