from model.bert_model_context import BertContextNLU
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from transformers import BertTokenizer, BertModel, AutoModel
from model.bert_model_context import BertContextNLU
from model.transformer import TransformerModel
from model.transformer_new import Transformer
from model.CHAN import ContextAttention
from model.torchcrf import CRF
from model.mia import MutualIterativeAttention

class RecATT(BertContextNLU):
    """Joint intent / slot tagger with recurrent attention over knowledge memory.

    Token features from BERT and turn-level context features are projected
    into a query; for each hop the query attends, token by token, over a
    memory built from the knowledge inputs.  The attended vector of one token
    is added to the query of the next token, which makes the attention
    recurrent along the sequence.  The attended features are concatenated with
    the context features (intent head) and the BERT token features (slot
    head) and fed to linear classifiers.

    ``self.bert``, ``self.context_encoder``, ``self.tanh``, ``self.dropout``,
    ``self.device``, ``self.opt``, ``self.hidden_size``, ``self.rnn_hidden``
    and ``self.num_slot_labels`` are expected to be provided by the base
    class.
    """

    # Trailing dimensions (n_kg, h_kg) of ``result_kg`` that the layers are sized for.
    _KG_ENTRIES = 5
    _KG_EMBED_DIM = 300

    def __init__(self, config, opt, num_labels=2, num_slot_labels=144):
        """Build the knowledge-attention layers and the two classifiers.

        Args:
            config: forwarded unchanged to the base class.
            opt: forwarded to the base class; ``opt.hop`` (number of attention
                hops) and ``opt.memory_size`` (memory slots per token) are
                read here.
            num_labels: number of intent labels.
            num_slot_labels: number of slot labels.
        """
        super(RecATT, self).__init__(config, opt, num_labels, num_slot_labels)

        self.kg_hidden_size = 200
        kg_dim = 2 * self.kg_hidden_size
        # Width of [knowledge features ; encoder features].  The encoder part
        # is the same tensor that Q_slot / Q_context consume, so it is sized
        # from self.hidden_size instead of a hard-coded 768.
        fused_dim = kg_dim + self.hidden_size

        # NOTE: layers are created in the same order as before so that seeded
        # initialisation and state_dict keys stay unchanged.

        # level 1 transform
        self.k_linear = nn.Linear(self._KG_EMBED_DIM, kg_dim)
        self.Q_slot = nn.Linear(self.hidden_size, self.kg_hidden_size)
        self.Q_context = nn.Linear(self.hidden_size, self.kg_hidden_size)

        # knowledge cell: one memory projection per hop, one shared scorer
        self.memory = nn.ModuleList(
            [nn.Linear(self._KG_ENTRIES, opt.memory_size) for _ in range(opt.hop)]
        )
        self.score = nn.Linear(kg_dim, 1)

        # prediction
        # The two LSTMs are not applied in forward(); they are kept so that
        # existing checkpoints still load.
        self.rnn = nn.LSTM(input_size=fused_dim,
                           hidden_size=self.rnn_hidden,
                           batch_first=True,
                           num_layers=1)
        self.slot_rnn = nn.LSTM(input_size=fused_dim,
                                hidden_size=self.rnn_hidden,
                                batch_first=True,
                                bidirectional=True,
                                num_layers=1)
        self.intent_classifier = nn.Linear(fused_dim, num_labels)
        self.slot_classifier = nn.Linear(fused_dim, num_slot_labels)

    def forward(self, result_ids, result_token_masks, result_masks, lengths, result_slot_labels, labels,
                y_caps, y_masks, slot_caps, slot_masks, result_kg):
        """Compute intent logits and slot logits for a batch of dialogues.

        Inputs:
        result_ids:         (b, d, t)
        result_token_masks: (b, d, t)
        result_masks:       (b, d)
        lengths:            (b)
        result_slot_labels: (b, d, t)
        labels:             (b, d, l)
        y_caps:             (n_intent, t)
        y_masks:            (n_intent, t)
        slot_caps:          (n_slot, t)
        slot_masks:         (n_slot, t)
        result_kg:          (b, d, t, n_kg, h_kg)

        ``result_slot_labels``, ``y_caps``, ``y_masks``, ``slot_caps`` and
        ``slot_masks`` are accepted for interface compatibility but are not
        used by this method.

        BERT outputs:
        last_hidden_states: (bxd, t, h)
        pooled_output:      (bxd, h)        , from output of a linear classifier + tanh
        hidden_states:      13 x (bxd, t, h), embed to last layer embedding
        attentions:         12 x (bxd, num_heads, t, t)

        Returns:
        logits:   (sum(lengths), l), intent logits of the first ``lengths[i]``
                  turns of every dialogue, concatenated over the batch
        labels:   the input ``labels`` with the same padding turns removed
        slot_out: (b*d*t, num_slot_labels), slot logits (padding NOT removed)
        """
        kg_dim = 2 * self.kg_hidden_size

        ############### 1. Token-level BERT encoding ###############
        b, d, t = result_ids.shape
        flat_ids = result_ids.view(-1, t)
        flat_token_masks = result_token_masks.view(-1, t)
        # The encoder is expected to return a 4-tuple; only the first two
        # entries are used.
        last_hidden_states, pooled_output, _, _ = self.bert(flat_ids, attention_mask=flat_token_masks)

        pooled_output = pooled_output.view(b, d, self.hidden_size)

        ############### 2. Turn-level Context encoding ###############

        ## Turn: CHAN
        # context_inputs is concatenated with a (b, d, .) tensor on dim=2
        # below, i.e. it is (b, d, h); the second return value is unused.
        context_inputs, _ = self.context_encoder(pooled_output, result_masks)

        ############### 3. Knowledge recurrent attention ###############
        slot_inputs = last_hidden_states  # (b*d, t, h)

        # knowledge attention
        kg_inputs = result_kg.reshape(b * d * t, self._KG_ENTRIES, self._KG_EMBED_DIM)
        # NOTE(review): the projection below has shape (b*d*t, n_kg, 2*h_kg)
        # and is *reshaped* (not transposed) to (b*d, t, 2*h_kg, n_kg).
        # Left as is because trained weights depend on this layout; confirm
        # that it is intended.
        v_kg = self.tanh(self.k_linear(kg_inputs)).reshape(b * d, t, -1, self._KG_ENTRIES)
        # One context query per turn, shared by all t tokens of that turn.
        q_context = self.Q_context(context_inputs).reshape(b * d, -1).unsqueeze(1).expand(-1, t, -1)  # (b*d, t, h_kg)
        q_slot = self.Q_slot(slot_inputs)  # (b*d, t, h_kg)

        slot_query = torch.cat([q_slot, q_context], dim=-1)  # (b*d, t, 2*h_kg)

        for hop in range(self.opt.hop):
            # The memory projection does not depend on the recurrent state,
            # so it is computed once per hop for all tokens instead of once
            # per token.  Slicing token j afterwards yields the same
            # (b*d, memory_size, 2*h_kg) block as projecting v_kg[:, j] alone.
            kg_memory_all = self.memory[hop](v_kg).reshape(b * d, t, -1, kg_dim)  # (b*d, t, memory_size, 2*h_kg)

            hop_states = torch.zeros(b * d, t, kg_dim).to(self.device)
            # NOTE(review): the initial recurrent state is drawn from
            # torch.randn on every call (evaluation included), so outputs
            # vary between calls unless the RNG is seeded.  Behaviour kept
            # unchanged; confirm that this is intended.
            fused_v_kg = torch.randn(b * d, kg_dim).to(self.device)

            for j in range(t):
                kg_memory = kg_memory_all[:, j]  # (b*d, memory_size, 2*h_kg)

                # Query of token j plus the attended knowledge of token j-1;
                # the singleton axis broadcasts over the memory slots.
                query = (slot_query[:, j, :] + fused_v_kg).unsqueeze(1)  # (b*d, 1, 2*h_kg)

                # attention over the memory slots
                score_kg = F.softmax(self.score(self.tanh(query + kg_memory)), dim=1)  # (b*d, memory_size, 1)
                fused_v_kg = torch.bmm(score_kg.permute(0, 2, 1), kg_memory).squeeze(1)  # (b*d, 2*h_kg)
                hop_states[:, j, :] = fused_v_kg

            # The attended features become the queries of the next hop.
            slot_query = hop_states

        ############### 4. Intent prediction ###############
        intent_kg = slot_query.mean(dim=1).reshape(b, d, -1)  # average over tokens
        fused_intent = torch.cat([intent_kg, context_inputs], dim=2)
        # self.rnn is intentionally not applied here.
        logits = self.intent_classifier(self.dropout(fused_intent))  # (b, d, l)

        # Remove padding turns: keep the first lengths[i] turns of dialogue i.
        logits = torch.cat([logits[i, :lengths[i], :] for i in range(b)], dim=0)
        labels = torch.cat([labels[i, :lengths[i], :] for i in range(b)], dim=0)

        ############### 5. Slot prediction ###############
        fused_slot = torch.cat([slot_query, slot_inputs], dim=2)  # (b*d, t, 2*h_kg + h)

        # self.slot_rnn is intentionally not applied here.
        slot_out = self.slot_classifier(self.dropout(fused_slot))
        slot_out = slot_out.view(-1, self.num_slot_labels)  # (b*d*t, num_slots)

        return logits, labels, slot_out