from multiprocessing import context
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from transformers import BertTokenizer, BertModel, AutoModel
from model.CHAN import ContextAttention
from model.mia import MutualIterativeAttention
from model.kencoder import KnowledgeEncoder


class BertFuse(nn.Module):
    """Joint intent / slot model that fuses BERT token features with knowledge features.

    Every dialogue turn is encoded by BERT, the token states are passed
    through a knowledge encoder together with the knowledge features, and two
    heads are decoded:

    * intent: one prediction per dialogue turn;
    * slot:   one prediction per token.

    When ``opt.run_baseline == 'ernie'`` the intent head is applied directly
    to the first-token state of the knowledge encoder; otherwise the turn
    vectors additionally go through the context encoder and a unidirectional
    LSTM over the turns.
    """

    def __init__(self, config, opt, num_labels=2, num_slot_labels=144):
        """Build all sub-modules.

        Args:
            config: BERT configuration; ``config.hidden_size`` is read here.
            opt: run options; ``opt.model`` and ``opt.rnn_hidden`` are read
                here and ``opt.run_baseline`` is read in ``forward``.
            num_labels: number of intent classes.
            num_slot_labels: number of slot classes.
        """
        super(BertFuse, self).__init__()
        self.opt = opt
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model_mode = opt.model
        self.num_labels = num_labels
        self.num_slot_labels = num_slot_labels
        self.hidden_size = config.hidden_size
        # Width of the flattened knowledge feature fed to the slot RNN.
        self.kg_size = 1000
        self.rnn_hidden = opt.rnn_hidden

        # NOTE: the construction order below is kept stable on purpose so that
        # state_dict keys and parameter-initialisation RNG order do not change.
        self.bert = BertModel.from_pretrained('bert-base-uncased', config=config)
        self.dropout = nn.Dropout(0.1)

        # Knowledge encoder
        self.knowledge_encoder = KnowledgeEncoder(self.device, self.hidden_size, self.kg_size)

        # Context encoder
        self.context_encoder = ContextAttention(self.device)

        # Decoder
        self.intent_rnn = nn.LSTM(input_size=2*self.hidden_size,
                                hidden_size=self.hidden_size,
                                batch_first=True,
                                bidirectional=False,
                                num_layers=1)
        self.slot_rnn = nn.LSTM(input_size=self.hidden_size+self.kg_size,
                                hidden_size=self.hidden_size,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.intent_classifier = nn.Linear(self.hidden_size, self.num_labels)
        self.slot_classifier = nn.Linear(2*self.hidden_size, num_slot_labels)

        # DiSAN (only used by the DiSAN() method; fc1 is not referenced in this class)
        self.conv1 = nn.Conv1d(self.hidden_size, self.hidden_size, 3, padding=1)
        self.conv2 = nn.Conv1d(self.hidden_size, self.hidden_size, 3, padding=1)
        self.fc1 = nn.Linear(2*self.hidden_size, self.rnn_hidden)

        print('Model: ', self.model_mode)

    def DiSAN(self, pooled_output, d, b):
        """Re-weight turn vectors with convolutional attention over the turn axis.

        Args:
            pooled_output: tensor holding ``b * d * hidden_size`` elements,
                laid out as (b, d, hidden_size).
            d: number of dialogue turns.
            b: batch size.

        Returns:
            Tensor of shape (b, d, hidden_size): the input multiplied by
            per-channel weights that sum to 1 across the ``d`` turns.
        """
        turns = pooled_output.reshape(b, d, self.hidden_size)
        # Conv1d wants channels first, i.e. (b, h, d).  This must be a real
        # axis swap: a plain view() would reinterpret memory and mix features
        # of different turns together.
        scores = turns.transpose(1, 2)
        scores = torch.sigmoid(self.conv1(scores))
        scores = self.conv2(scores)
        scores = F.softmax(scores, dim=-1)  # normalise over the turn axis
        scores = scores.transpose(1, 2)     # back to (b, d, h)
        return scores * turns

    @staticmethod
    def _remove_padding(logits, labels, lengths):
        """Keep the first ``lengths[i]`` turns of dialogue ``i`` and concatenate them."""
        batch = logits.shape[0]
        kept_logits = [logits[i, :lengths[i], :] for i in range(batch)]
        kept_labels = [labels[i, :lengths[i], :] for i in range(batch)]
        return torch.cat(kept_logits, dim=0), torch.cat(kept_labels, dim=0)

    def _decode_slots(self, token_output, knowledge_hidden):
        """Run the slot BiLSTM + classifier and flatten to (b*d*t, num_slot_labels)."""
        fused_input = torch.cat([token_output, knowledge_hidden], dim=2)
        slot_rnn_out, _ = self.slot_rnn(fused_input)
        slot_rnn_out = self.dropout(slot_rnn_out)
        slot_out = self.slot_classifier(slot_rnn_out)
        return slot_out.reshape(-1, self.num_slot_labels)

    def forward(self, result_ids, result_token_masks, result_masks, lengths, result_slot_labels, labels,
                y_caps, y_masks, slot_caps, slot_masks, result_kg):
        """Compute intent logits, padding-free intent labels and slot logits.

        Inputs:
        result_ids:         (b, d, t)
        result_token_masks: (b, d, t)
        result_masks:       (b, d)
        lengths:            (b)
        result_slot_labels: (b, d, t)
        labels:             (b, d, l)
        result_kg:          (b, d, t, 5, 300)

        ``result_slot_labels``, ``y_caps``, ``y_masks``, ``slot_caps`` and
        ``slot_masks`` are accepted for interface compatibility but are not
        read by this method.

        BERT outputs:
        last_hidden_states: (bxd, t, h)
        pooled_output: (bxd, h), from output of a linear classifier + tanh

        Returns:
        logits:   (sum(lengths), num_labels)  intent logits of real turns only
        labels:   (sum(lengths), l)           matching intent labels
        slot_out: (b*d*t, num_slot_labels)    slot logits for every token
        """
        b, d, t = result_ids.shape

        ############### 1. Token-level BERT encoding ###############
        result_ids = result_ids.reshape(-1, t)
        result_token_masks = result_token_masks.reshape(-1, t)
        # Drop the first 100 entries of the last axis, then flatten the two
        # trailing axes into one feature vector per token.
        result_kg = result_kg[:, :, :, :, 100:].reshape(b*d, t, -1)
        # Index instead of tuple-unpacking: this works for the legacy tuple
        # return as well as for dict-like ModelOutput objects, whose iteration
        # yields key names rather than tensors.
        bert_outputs = self.bert(result_ids, attention_mask=result_token_masks)
        context_token_output = bert_outputs[0].reshape(-1, t, self.hidden_size)
        context_pooled_output = bert_outputs[1]

        ############### 2. Knowledge encoding ###############
        context_hidden, knowledge_hidden = self.knowledge_encoder(
            context_token_output, result_token_masks, result_kg)
        # First-token state of every turn, regrouped per dialogue: (b, d, h)
        context_hidden = context_hidden[:, 0, :].reshape(b, d, -1)

        if self.opt.run_baseline == 'ernie':
            # Baseline: classify intents straight from the knowledge-encoded turn vectors.
            logits = self.intent_classifier(context_hidden)  # (b,d,l)
        else:
            ############### 3. Context encoding ###############
            context_vector, ffscores = self.context_encoder(context_hidden, result_masks)
            context_pooled_output = context_pooled_output.reshape(b, d, self.hidden_size)
            pooled_output = torch.cat([context_pooled_output, context_vector], dim=-1)

            ############### 4. Intent: RNN prediction ###############
            rnn_out, _ = self.intent_rnn(pooled_output)
            rnn_out = self.dropout(rnn_out)
            logits = self.intent_classifier(rnn_out)  # (b,d,l)

        # Remove padding turns
        logits, labels = self._remove_padding(logits, labels, lengths)

        ############### 5. Slot: RNN prediction ###############
        slot_out = self._decode_slots(context_token_output, knowledge_hidden)

        return logits, labels, slot_out













        











