#coding:utf-8
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import logging
import time
from torch.nn import utils as nn_utils
import json

from .modeling_edgebert import BertConfig, BertEncoder

# Configure root logging at INFO level when this module is imported
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level = logging.INFO)

class WordEmbedding(nn.Module):
    """Trainable word-embedding lookup initialised from pretrained vectors.

    Args:
        word_embed: 2-D array of pretrained vectors, one row per word id
            (only ``word_embed.shape[0]`` and conversion through
            ``torch.Tensor`` are required of it).
        embed_units: Embedding width handed to ``nn.Embedding``.
    """

    def __init__(self, word_embed, embed_units):
        super(WordEmbedding, self).__init__()

        self.embed_units = embed_units
        # Index 0 is the padding index: its row receives no gradient updates.
        self.word_embedding = nn.Embedding(
            num_embeddings=word_embed.shape[0],
            embedding_dim=self.embed_units,
            padding_idx=0,
        )

        pretrained = torch.Tensor(word_embed)
        # Keep the historical "weights live on the GPU right after
        # construction" behaviour, but do not crash on CPU-only hosts.
        if torch.cuda.is_available():
            pretrained = pretrained.cuda()
        self.word_embedding.weight = nn.Parameter(pretrained)
        self.word_embedding.weight.requires_grad = True

    def forward(self, text):
        """Return the embedding vectors for the integer ids in ``text``."""
        return self.word_embedding(text)

class EntityEmbedding(nn.Module):
    """Trainable entity-embedding lookup followed by a linear projection.

    The pretrained table is shifted by one row: an all-zero row is inserted
    at index 0 (the padding index), so pretrained row ``k`` is looked up with
    id ``k + 1``.

    Args:
        entity_embed: 2-D array of pretrained entity vectors.
        trans_units: Width of the entity vectors and of the projection.
    """

    def __init__(self, entity_embed, trans_units):
        super(EntityEmbedding, self).__init__()

        self.trans_units = trans_units
        self.entity_embedding = nn.Embedding(
            num_embeddings=entity_embed.shape[0] + 1,
            embedding_dim=self.trans_units,
            padding_idx=0,
        )

        # Prepend the zero vector used for the padding id.
        pretrained = torch.cat(
            (torch.zeros(1, self.trans_units), torch.Tensor(entity_embed)), 0)
        # Keep the historical "weights live on the GPU right after
        # construction" behaviour, but do not crash on CPU-only hosts.
        if torch.cuda.is_available():
            pretrained = pretrained.cuda()
        self.entity_embedding.weight = nn.Parameter(pretrained)
        self.entity_embedding.weight.requires_grad = True

        self.entity_linear = nn.Linear(
            in_features=self.trans_units, out_features=self.trans_units)

    def forward(self, entity):
        """Look up the ids in ``entity`` and apply the linear projection."""
        return self.entity_linear(self.entity_embedding(entity))

class RelationEmbedding(nn.Module):
    """Trainable relation-embedding lookup followed by a linear projection.

    Four all-zero rows are inserted in front of the pretrained table, so
    pretrained row ``k`` is looked up with id ``k + 4``; id 0 is the padding
    index.

    Args:
        relation_embed: 2-D array of pretrained relation vectors.
        trans_units: Width of the relation vectors and of the projection.
    """

    def __init__(self, relation_embed, trans_units):
        super(RelationEmbedding, self).__init__()

        self.trans_units = trans_units
        self.relation_embedding = nn.Embedding(
            num_embeddings=relation_embed.shape[0] + 4,
            embedding_dim=self.trans_units,
            padding_idx=0,
        )

        # Reserve ids 0..3 with zero-initialised (but trainable) vectors.
        pretrained = torch.cat(
            (torch.zeros(4, self.trans_units), torch.Tensor(relation_embed)), 0)
        # Keep the historical "weights live on the GPU right after
        # construction" behaviour, but do not crash on CPU-only hosts.
        if torch.cuda.is_available():
            pretrained = pretrained.cuda()
        self.relation_embedding.weight = nn.Parameter(pretrained)
        self.relation_embedding.weight.requires_grad = True

        # NOTE: the attribute name is misspelled ("relatioin"), but it is part
        # of the state_dict keys, so it is kept to stay checkpoint-compatible.
        self.relatioin_linear = nn.Linear(
            in_features=self.trans_units, out_features=self.trans_units)

    def forward(self, relation):
        """Look up the ids in ``relation`` and apply the linear projection."""
        return self.relatioin_linear(self.relation_embedding(relation))

class DialogModel(nn.Module):
    """Dialogue generator that can emit vocabulary words or graph entities.

    A GRU encodes the query text, a graph encoder (``BertEncoder`` from
    ``modeling_edgebert``) encodes the query summary together with the entity
    nodes, and a GRU decoder attends over both the encoded text and the
    encoded nodes. At every step a two-way selector decides between a word
    from the vocabulary and an entity from the graph.

    Set ``is_inference`` to ``True`` to switch ``forward`` from loss
    computation to greedy decoding; ``word2id`` and ``id2entity`` must be
    assigned by the caller before decoding because ``inference`` reads them.
    """

    def __init__(self, config, word_embed, entity_embed, relation_embed):
        super(DialogModel, self).__init__()

        # Mode switch read by forward(): False -> loss, True -> greedy decode.
        self.is_inference = False
        self.trans_units = config.trans_units
        self.embed_units = config.embed_units
        self.units = config.units
        self.layers = config.layers
        self.gnn_layers = config.gnn_layers
        self.symbols = config.symbols
        self.path = None

        # Vocabulary / entity lookup tables; left unset here and expected to
        # be filled in by the caller (inference() uses word2id and id2entity).
        self.word2id = None
        self.entity2id = None
        self.id2entity = None

        self.WordEmbedding = WordEmbedding(word_embed, self.embed_units)
        self.EntityEmbedding = EntityEmbedding(entity_embed, self.trans_units)
        self.RelationEmbedding = RelationEmbedding(relation_embed, self.trans_units)
        # Positional arguments are defined by modeling_edgebert.BertConfig;
        # presumably (hidden size, layers, heads, intermediate size, dropout)
        # -- TODO confirm against that module.
        self.BertConfig = BertConfig(
            self.trans_units, config.gnn_layers, 2, self.trans_units * 4, config.linear_dropout)
        self.Bert = BertEncoder(self.BertConfig)

        self.text_encoder = nn.GRU(
            input_size=self.embed_units, hidden_size=self.units,
            num_layers=self.layers, batch_first=True)
        # The decoder consumes the previous word embedding concatenated with
        # the previous attention context.
        self.decoder = nn.GRU(
            input_size=self.units + self.embed_units, hidden_size=self.units,
            num_layers=self.layers, batch_first=True)
        self.text2node_linear = nn.Linear(in_features=self.units, out_features=self.trans_units)

        self.attn_c_linear = nn.Linear(in_features=self.units, out_features=self.units, bias=False)
        # Produces attention keys and values for the graph nodes in one go.
        self.attn_ce_linear = nn.Linear(in_features=self.trans_units, out_features=2 * self.units, bias=False)
        self.context_linear = nn.Linear(in_features=3 * self.units, out_features=self.units, bias=False)
        self.logits_linear = nn.Linear(in_features=self.units, out_features=self.symbols)
        self.selector_linear = nn.Linear(in_features=self.units, out_features=2)

        self.softmax_d1 = nn.Softmax(dim=1)
        self.softmax_d2 = nn.Softmax(dim=2)

    def forward(self, query_text, answer_text, graph_node, edge, graph_mask, one_hot_entity):
        """Compute the training losses or greedily decode a response.

        All tensors are expected on the same device as the model parameters.

        Args:
            query_text: Integer word ids, batch first; id 0 is padding.
            answer_text: Integer word ids of the target response, batch
                first; id 0 is padding. Its width sets the number of decode
                steps in both modes.
            graph_node: Integer entity ids, batch first; id 0 is padding.
            edge: Relation ids fed to ``RelationEmbedding`` and then to the
                graph encoder (layout defined by ``BertEncoder``).
            graph_mask: 3-D mask for the graph encoder; a size-2 axis is
                inserted at dim 1 before it is handed over.
            one_hot_entity: 3-D tensor multiplied elementwise with the
                per-step node alignments; summed over dim 2 to flag steps
                whose target is an entity. Only used in training mode.

        Returns:
            Training mode: the 6-tuple returned by ``total_loss``.
            Inference mode: ``(word_ids, selectors, alignments)`` where the
            first two are nested Python lists (batch x steps) and the last
            is a NumPy array of node alignments.
        """
        device = query_text.device
        # `.float()` keeps the masks on the inputs' device; the previous
        # `.type('torch.FloatTensor').cuda()` bounced them through the CPU.
        query_mask = (query_text != 0).float()
        decoder_mask = (answer_text != 0).float()
        graph_len_mask = (graph_node != 0).float()

        batch_size = query_text.shape[0]
        graph_len = graph_node.shape[1]
        decoder_len = answer_text.shape[1]

        # encode post text
        text_encoder_input = self.WordEmbedding(query_text)
        initial_state = text_encoder_input.new_zeros((self.layers, batch_size, self.units))
        text_encoder_output, text_encoder_state = self.text_encoder(text_encoder_input, initial_state)
        # Final hidden state of the top GRU layer summarises the query.
        text_feature = text_encoder_state[-1]

        # encode nodes; the projected text summary becomes an extra node 0
        node_embed = self.EntityEmbedding(graph_node)
        text_node = self.text2node_linear(text_feature).view(batch_size, 1, self.trans_units)

        # encode graph
        graph_input = torch.cat((text_node, node_embed), 1)
        # Presumably one mask copy per attention head (2 matches the third
        # BertConfig argument) -- TODO confirm against modeling_edgebert.
        graph_mask = torch.unsqueeze(graph_mask, 1).repeat(1, 2, 1, 1)
        edge_embedding = self.RelationEmbedding(edge)
        graph_output = self.Bert(graph_input, edge_embedding, graph_mask)
        # Drop the text node again; only entity nodes are attended over.
        node_output = torch.split(graph_output, [1, graph_len], 1)[1]

        # attention keys and values
        c_attention_keys = self.attn_c_linear(text_encoder_output)
        c_attention_values = text_encoder_output
        ce_attention_keys, ce_attention_values = torch.split(
            self.attn_ce_linear(node_output), [self.units, self.units], 2)

        decoder_state = text_encoder_state
        context = text_encoder_output.new_zeros((batch_size, self.units))
        # Per-step results are collected in lists and stacked once, instead
        # of re-concatenating a growing tensor at every step.
        alignment_steps = []

        if not self.is_inference:
            # Teacher forcing: shift the target right and start with id 1
            # (presumably the start-of-sequence symbol).
            go_tokens = torch.ones((batch_size, 1), dtype=torch.long, device=device)
            responses_id = torch.cat((go_tokens, answer_text[:, :decoder_len - 1]), 1)
            decoder_input = self.WordEmbedding(responses_id)

            output_steps = []
            for t in range(decoder_len):
                step_input = torch.cat((decoder_input[:, t, :], context), 1).unsqueeze(1)
                step_output, decoder_state = self.decoder(step_input, decoder_state)
                context, step_alignment = self.attention(
                    step_output.squeeze(1),
                    c_attention_keys, c_attention_values, query_mask,
                    ce_attention_keys, ce_attention_values, graph_len_mask)
                alignment_steps.append(step_alignment)
                # The attention context (not the raw GRU output) is what the
                # output layers see.
                output_steps.append(context)

            decoder_output = torch.stack(output_steps, 1)
            ce_alignments = torch.stack(alignment_steps, 1)
            use_entity = torch.sum(one_hot_entity, 2)
            return self.total_loss(
                decoder_output, answer_text, decoder_mask,
                ce_alignments, use_entity, one_hot_entity)

        # Greedy decoding: feed back the embedding of the token just chosen.
        step_embedding = self.WordEmbedding(
            torch.ones(batch_size, dtype=torch.long, device=device))
        word_steps = []
        selector_steps = []
        for t in range(decoder_len):
            step_input = torch.cat((step_embedding, context), 1).unsqueeze(1)
            step_output, decoder_state = self.decoder(step_input, decoder_state)
            context, step_alignment = self.attention(
                step_output.squeeze(1),
                c_attention_keys, c_attention_values, query_mask,
                ce_attention_keys, ce_attention_values, graph_len_mask)
            alignment_steps.append(step_alignment)

            step_embedding, step_words, step_selector = self.inference(
                context.unsqueeze(1), step_alignment, graph_node)
            word_steps.append(step_words)
            selector_steps.append(step_selector)

        word_index = torch.stack(word_steps, 1)
        selector = torch.stack(selector_steps, 1)
        ce_alignments = torch.stack(alignment_steps, 1)
        return (word_index.cpu().numpy().tolist(),
                selector.cpu().numpy().tolist(),
                ce_alignments.detach().cpu().numpy())

    def inference(self, decoder_output_t, entity_alignment, entity):
        """Pick the next token for one greedy decoding step.

        Args:
            decoder_output_t: Step output with a singleton dim 1, i.e.
                (batch, 1, units).
            entity_alignment: Attention weights over graph nodes for this
                step, (batch, graph_len).
            entity: Entity ids of the graph nodes, indexed as
                ``entity[i][j]``.

        Returns:
            ``(decoder_input_t, word_index_final_t, selector)``: the word
            embedding of the chosen tokens, their word ids (LongTensor), and
            the per-example choice (0 = vocabulary word, 1 = entity).
        """
        step_output = decoder_output_t.squeeze(1)
        logits = self.logits_linear(step_output)  # batch * num_symbols
        selector_prob = self.softmax_d1(self.selector_linear(step_output))

        word_prob, word_t = torch.max(
            selector_prob[:, 0].unsqueeze(1) * self.softmax_d1(logits), dim=1)
        entity_prob, entity_index_t = torch.max(
            selector_prob[:, 1].unsqueeze(1) * entity_alignment, dim=1)

        # NOTE(review): the selector probability is applied a second time
        # here (it is already folded into word_prob / entity_prob). Kept
        # as-is so decoding results do not change.
        mode_scores = torch.stack(
            (selector_prob[:, 0] * word_prob, selector_prob[:, 1] * entity_prob), 1)
        selector = torch.argmax(mode_scores, dim=1)

        # Convert once to Python lists so the loop below does not force a
        # device synchronisation per batch element.
        chosen_modes = selector.tolist()
        word_ids = word_t.tolist()
        entity_columns = entity_index_t.tolist()

        word_index_final_t = []
        for i, mode in enumerate(chosen_modes):
            if mode == 0:
                word_index_final_t.append(word_ids[i])
                continue
            # Entity chosen: map node position -> entity id -> surface text
            # -> word id, falling back to '_UNK' for out-of-vocabulary text.
            entity_t = int(entity[i][entity_columns[i]])
            entity_text = self.id2entity[entity_t]
            if entity_text not in self.word2id:
                entity_text = '_UNK'
            word_index_final_t.append(self.word2id[entity_text])

        word_index_final_t = torch.tensor(
            word_index_final_t, dtype=torch.long, device=step_output.device)
        decoder_input_t = self.WordEmbedding(word_index_final_t)

        return decoder_input_t, word_index_final_t, selector

    def total_loss(self, decoder_output, responses_target, decoder_mask,
                   entity_alignment, use_entity, entity_targets):
        """Compute the training loss and perplexity statistics.

        Args:
            decoder_output: (batch, decoder_len, units) step outputs.
            responses_target: (batch, decoder_len) target word ids.
            decoder_mask: (batch, decoder_len), non-zero on real tokens.
            entity_alignment: (batch, decoder_len, graph_len) attention
                weights over graph nodes.
            use_entity: (batch, decoder_len), non-zero where the target
                token is an entity.
            entity_targets: Multiplied elementwise with ``entity_alignment``
                and summed over dim 2 to get the target-entity probability.

        Returns:
            ``(loss, ppx, ppx_word, ppx_entity, word_neg_num,
            entity_neg_num)``: the masked mean of generation plus selector
            negative log-likelihood; batch sums of per-sentence average
            negative log-likelihood overall, on word steps and on entity
            steps; and the number of sentences with no word step / no entity
            step.
        """
        batch_size = decoder_output.shape[0]
        eps = 1e-12

        # `.float()` keeps everything on the inputs' device instead of
        # bouncing through the CPU.
        token_mask = decoder_mask.float()
        local_masks = token_mask.reshape([-1])
        local_masks_word = (1 - use_entity).reshape([-1]).float() * local_masks
        local_masks_entity = use_entity.reshape([-1]).float()

        logits = self.logits_linear(decoder_output)  # batch * decoder_len * num_symbols
        word_prob = torch.gather(
            self.softmax_d2(logits), 2, responses_target.unsqueeze(2)).squeeze(2)
        selector_word, selector_entity = self.softmax_d2(
            self.selector_linear(decoder_output)).unbind(2)

        entity_prob = torch.sum(entity_alignment * entity_targets, 2)
        ppx_prob = word_prob * (1 - use_entity) + entity_prob * use_entity
        ppx_word = word_prob * (1 - use_entity)
        ppx_entity = entity_prob * use_entity

        final_prob = (word_prob * selector_word * (1 - use_entity)
                      + entity_prob * selector_entity * use_entity)
        final_loss = torch.sum(-torch.log(eps + final_prob).reshape([-1]) * local_masks)

        sentence_ppx = torch.sum(
            (-torch.log(eps + ppx_prob).reshape([-1]) * local_masks)
            .reshape([batch_size, -1]), 1)
        sentence_ppx_word = torch.sum(
            (-torch.log(eps + ppx_word).reshape([-1]) * local_masks_word)
            .reshape([batch_size, -1]), 1)
        sentence_ppx_entity = torch.sum(
            (-torch.log(eps + ppx_entity).reshape([-1]) * local_masks_entity)
            .reshape([batch_size, -1]), 1)

        selector_loss = torch.sum(
            -torch.log(eps + selector_entity * use_entity
                       + selector_word * (1 - use_entity)).reshape([-1]) * local_masks)

        loss = final_loss + selector_loss
        # eps keeps the division defined when every token is masked out.
        total_size = torch.sum(local_masks) + eps

        sum_word = torch.sum(((1 - use_entity) * token_mask).float(), 1)
        sum_entity = torch.sum(use_entity.float(), 1)
        word_neg_mask = (sum_word == 0).float()
        entity_neg_mask = (sum_entity == 0).float()
        word_neg_num = torch.sum(word_neg_mask)
        entity_neg_num = torch.sum(entity_neg_mask)

        # Sentences without any word / entity step get a denominator of 1
        # (their numerator is already 0) to avoid 0/0.
        sum_word = sum_word + word_neg_mask
        sum_entity = sum_entity + entity_neg_mask
        # Same guard for fully padded targets, which previously yielded NaN.
        sentence_len = torch.sum(token_mask, 1).clamp(min=1)

        return (loss / total_size,
                torch.sum(sentence_ppx / sentence_len),
                torch.sum(sentence_ppx_word / sum_word),
                torch.sum(sentence_ppx_entity / sum_entity),
                word_neg_num, entity_neg_num)

    def attention(self, decoder_state,
                  c_attention_keys, c_attention_values, query_mask,
                  ce_attention_keys, ce_attention_values, graph_len_mask):
        """Dot-product attention over the encoded text and the graph nodes.

        Args:
            decoder_state: (batch, units) decoder output for this step.
            c_attention_keys, c_attention_values: Text keys / values,
                (batch, query_len, units).
            query_mask: (batch, query_len), zero on padded text positions.
            ce_attention_keys, ce_attention_values: Node keys / values,
                (batch, graph_len, units).
            graph_len_mask: (batch, graph_len), zero on padded nodes.

        Returns:
            ``(context, ce_alignments)``: the (batch, units) projection of
            ``[decoder_state; text context; node context]`` and the masked
            node attention weights, (batch, graph_len).

        Note:
            Padded positions are zeroed *after* the softmax and the weights
            are not renormalised, so each row may sum to less than 1.
        """
        query = decoder_state.reshape([-1, 1, self.units])

        c_scores = torch.sum(c_attention_keys * query, 2)
        # `.to(scores)` matches dtype and device without a CPU round-trip.
        c_alignments = self.softmax_d1(c_scores) * query_mask.to(c_scores)
        c_context = torch.sum(c_alignments.unsqueeze(2) * c_attention_values, 1)

        ce_scores = torch.sum(ce_attention_keys * query, 2)
        ce_alignments = self.softmax_d1(ce_scores) * graph_len_mask.to(ce_scores)
        ce_context = torch.sum(ce_alignments.unsqueeze(2) * ce_attention_values, 1)

        context = self.context_linear(torch.cat((decoder_state, c_context, ce_context), 1))

        return context, ce_alignments
