import os.path
import math
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss
from torch.nn import CosineEmbeddingLoss

# Module-level log of attention maps. Every call to
# MultiHeadAttention.attention appends a detached CPU copy of its
# post-softmax (and post-dropout) scores. Nothing in this module clears the
# list, so callers that read it are responsible for emptying it.
ffscores = []

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        heads: number of attention heads; must divide ``d_model`` evenly.
        d_model: feature size of the query/key/value inputs and of the output.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``heads``.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        # With a non-divisible size the per-head reshape in ``forward`` would
        # either fail or silently fold features into the sequence axis, so the
        # configuration is rejected up front.
        if d_model % heads != 0:
            raise ValueError(
                "d_model (%d) must be divisible by heads (%d)" % (d_model, heads)
            )

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

        # Attention weights of the most recent call (see ``get_scores``).
        self.scores = None

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        Args:
            q, k, v: per-head projections; ``forward`` passes tensors of shape
                (batch, heads, seq, d_k).
            d_k: per-head feature size used for scaling.
            mask: optional tensor; a head axis is inserted at dim 1 and
                positions where the mask equals 0 are excluded.
                Presumably (batch, q_len, k_len) or (batch, 1, k_len) --
                confirm against callers.
            dropout: optional module applied to the attention weights.

        Returns:
            The attention-weighted values, same leading shape as ``q``.

        Side effects:
            Stores the weights on ``self.scores`` and appends a detached CPU
            copy to the module-level ``ffscores`` list.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # One mask is shared by every head.
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        self.scores = scores
        # Detach before logging: keeping graph-attached tensors in a global
        # list would pin the autograd graph of every forward pass in memory.
        ffscores.append(scores.detach().cpu())
        return torch.matmul(scores, v)

    def forward(self, q, k, v, mask=None):
        """Project the inputs, attend per head, and merge the heads.

        Args:
            q, k, v: tensors whose first dimension is the batch size and whose
                last dimension is ``d_model``.
            mask: optional mask forwarded to ``attention``.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        bs = q.size(0)

        # Linear projection, then split the feature axis into h heads.
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # Move the head axis forward: (bs, h, seq_len, d_k).
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        attended = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # Concatenate the heads and apply the output projection.
        concat = attended.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

    def get_scores(self):
        """Return the attention weights of the most recent call (or None)."""
        return self.scores

class CoAttention(nn.Module):
    """Bidirectional attention between a context and a knowledge sequence.

    A bilinear affinity ``Wb(knowledge) @ context^T`` is computed once and
    normalised in both directions: over context positions to pool the
    context, and over knowledge positions to pool the knowledge.

    Args:
        d_model: feature size of the context sequence.
        d_model_k: feature size of the knowledge sequence.
        hidden_dim: output size of the ``Wc`` / ``Wk`` projections.
        dropout: accepted for interface compatibility; not used.
    """

    def __init__(self, d_model, d_model_k, hidden_dim, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_model_k = d_model_k
        self.Wb = nn.Linear(d_model_k, d_model, bias=False)
        nn.init.xavier_normal_(self.Wb.weight)
        # Wc and Wk are not read by ``forward`` (they belonged to an earlier
        # tanh-scored variant). They are still created so the module's
        # parameter set and random-initialisation order stay unchanged.
        self.Wc = nn.Linear(d_model, hidden_dim, bias=False)
        nn.init.xavier_normal_(self.Wc.weight)
        self.Wk = nn.Linear(d_model_k, hidden_dim, bias=False)
        nn.init.xavier_normal_(self.Wk.weight)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    @staticmethod
    def _as_3d_mask(mask):
        """Lift a (b, t) key mask to (b, 1, t); leave other ranks untouched."""
        if mask.dim() == 2:
            return mask.unsqueeze(1)
        return mask

    def forward(self, context, knowledge, token_mask, knowledge_mask):
        """Attend each sequence over the other.

        Args:
            context: (b, ts, d_model).
            knowledge: (b, tk, d_model_k).
            token_mask: 0 marks context positions to ignore. Either (b, ts),
                or any tensor broadcastable to (b, tk, ts).
            knowledge_mask: 0 marks knowledge positions to ignore. Either
                (b, tk), or any tensor broadcastable to (b, ts, tk).

        Returns:
            Tuple ``(context_hidden, knowledge_hidden)`` with shapes
            (b, tk, d_model) and (b, ts, d_model_k).
        """
        # A 2-D mask would otherwise be broadcast against the wrong axes of
        # the 3-D affinity (or fail), so give it an explicit query axis.
        token_mask = self._as_3d_mask(token_mask)
        knowledge_mask = self._as_3d_mask(knowledge_mask)

        affinity = torch.bmm(self.Wb(knowledge), context.permute(0, 2, 1))  # (b, tk, ts)

        # Knowledge positions attending over context positions.
        affinity_c = affinity / math.sqrt(self.d_model)
        affinity_c = affinity_c.masked_fill(token_mask == 0, -1e9)
        c_att_score = self.softmax(affinity_c)  # (b, tk, ts)

        # Context positions attending over knowledge positions.
        affinity_t = affinity.permute(0, 2, 1) / math.sqrt(self.d_model_k)
        affinity_t = affinity_t.masked_fill(knowledge_mask == 0, -1e9)
        k_att_score = self.softmax(affinity_t)  # (b, ts, tk)

        context_hidden = torch.matmul(c_att_score, context)
        knowledge_hidden = torch.matmul(k_att_score, knowledge)

        return context_hidden, knowledge_hidden


def subsequent_mask(size):
    """Build a causal mask of shape (1, size, size).

    Entry ``[0, i, j]`` is True when ``j <= i`` (position ``i`` may look at
    ``j``) and False for every later position.
    """
    # Ones strictly above the diagonal mark the "future" positions.
    future = np.triu(np.ones((1, size, size), dtype='uint8'), k=1)
    return torch.from_numpy(future).eq(0)

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Computes ``w_2(dropout(relu(w_1(x))))``: expand from ``d_model`` to
    ``d_ff``, apply ReLU and dropout, then project back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor; the shape of ``x`` is preserved."""
        expanded = self.w_1(x)
        activated = F.relu(expanded)
        regularised = self.dropout(activated)
        return self.w_2(regularised)

#################################################################

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then dropout.

    Args:
        d_model: feature size of the inputs (even or odd).
        dropout: dropout probability applied after the addition.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                             -(math.log(10000.0) / d_model))
        # Sine fills the even feature indices, cosine the odd ones.
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosine.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcast over the batch
        # A buffer follows the module across devices but is not a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x`` is indexed as (batch, seq_len, d_model); ``seq_len`` must not
        exceed ``max_len`` or the addition fails to broadcast.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper around an arbitrary sublayer.

    The input is layer-normalised *before* it is handed to the sublayer; the
    sublayer's output goes through dropout and is added back to the original
    (un-normalised) input.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``.

        ``sublayer`` is any callable mapping a tensor to one of the same size.
        """
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention per stream, then co-attention.

    The context stream (width ``size``) and the knowledge stream (width
    ``ksize``) each pass through their own pre-norm residual self-attention
    block. The two results are then exchanged through a ``CoAttention``
    module whose hidden size is fixed at 300. No further fusion layer is
    applied after the co-attention.
    """

    def __init__(self, size, ksize, self_attn, self_attn_k, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.self_attn_k = self_attn_k
        self.sublayer1 = SublayerConnection(size, dropout)
        self.sublayer2 = SublayerConnection(ksize, dropout)
        self.coattlayer = CoAttention(size, ksize, 300)
        self.size = size

    def forward(self, x, y, token_mask, sequence_mask, knowledge_mask, k_mask):
        """Run both streams through the layer.

        Args:
            x: context stream.
            y: knowledge stream.
            token_mask: mask for context self-attention and co-attention.
            sequence_mask: accepted but not used by this layer.
            knowledge_mask: mask for knowledge self-attention and
                co-attention.
            k_mask: accepted but not used by this layer.

        Returns:
            Tuple of the two co-attention outputs, in the order returned by
            ``CoAttention``.
        """
        def attend_context(normed):
            return self.self_attn(normed, normed, normed, token_mask)

        def attend_knowledge(normed):
            return self.self_attn_k(normed, normed, normed, knowledge_mask)

        context = self.sublayer1(x, attend_context)
        knowledge = self.sublayer2(y, attend_knowledge)
        context, knowledge = self.coattlayer(
            context, knowledge, token_mask, knowledge_mask
        )
        return context, knowledge

def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

class Encoder(nn.Module):
    """Stack of N deep-copied layers applied in sequence to two streams.

    Each layer receives the pair ``(x, y)`` produced by the previous one,
    together with the same four masks. No final normalisation is applied.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)

    def forward(self, x, y, token_mask, sequence_mask, knowledge_mask, k_mask):
        """Thread ``(x, y)`` through every layer and return the final pair."""
        masks = (token_mask, sequence_mask, knowledge_mask, k_mask)
        for block in self.layers:
            x, y = block(x, y, *masks)
        return x, y

class KnowledgeEncoder(nn.Module):
    """Jointly encode a context sequence and an aligned knowledge sequence.

    Positional encodings are added to both inputs, which are then passed
    through a stack of ``EncoderLayer`` blocks using 4 attention heads.

    Args:
        device: device on which the masks are placed.
        dim: feature size of the context inputs.
        kdim: feature size of the knowledge inputs.
        layer: number of stacked encoder layers.
    """

    def __init__(self, device, dim=768, kdim=1500, layer=2):
        super(KnowledgeEncoder, self).__init__()

        self.attn_head = 4
        self.device = device

        self.add_pe_c = PositionalEncoding(dim, 0.)
        self.add_pe_k = PositionalEncoding(kdim, 0.)

        ### Knowledge encoder
        self.klayer = Encoder(EncoderLayer(dim,
                                           kdim,
                                           MultiHeadAttention(self.attn_head, dim, dropout=0.),
                                           MultiHeadAttention(self.attn_head, kdim, dropout=0.),
                                           0.1),
                              N=layer)

    def _make_aux_tensors(self, ids, len):
        """Build segment ids and an attention mask for ``ids``.

        Args:
            ids: integer tensor indexed as ``[i, j, position]``; positions
                with ``ids > 0`` count as attended.
            len: tensor indexed as ``[i, j, 0]`` and ``[i, j, 1]``. Presumably
                the lengths of the first and second text segment of each
                row -- confirm against the caller. (The name shadows the
                builtin; it is kept so keyword calls continue to work.)

        Returns:
            Tuple ``(token_type_ids, attention_mask)``. ``token_type_ids`` is
            a long tensor shaped like ``ids`` with 1 on the span
            ``[len[i,j,0], len[i,j,0] + len[i,j,1])`` and 0 elsewhere.
        """
        spans = len
        token_type_ids = torch.zeros(ids.size(), dtype=torch.long).to(self.device)
        for i in range(spans.size(0)):
            for j in range(spans.size(1)):
                if spans[i, j, 0] == 0:  # padding: skip the rest of this row
                    break
                elif spans[i, j, 1] > 0:  # escape only text_a case
                    start = spans[i, j, 0]
                    ending = spans[i, j, 0] + spans[i, j, 1]
                    token_type_ids[i, j, start:ending] = 1
        attention_mask = ids > 0
        return token_type_ids, attention_mask

    def forward(self, context_inputs, context_masks, knowledge_inputs):
        """
        context_inputs:   (b*d, t, h)
        context_masks:    (b*d, t)
        knowledge_inputs: (b*d, t, 5*300)

        Returns the ``(context_hidden, knowledge_hidden)`` pair produced by
        the encoder stack.

        NOTE(review): the pairwise masks built here are applied inside
        CoAttention to a (tk, ts) affinity, so this presumably requires the
        context and knowledge sequences to have the same length -- confirm.
        """
        ds, ts = context_masks.shape
        _, tk, _ = knowledge_inputs.shape

        # word mask: the causal pattern for rows that contain at least one
        # token, all zeros for rows that are entirely padding.
        has_tokens = context_masks.sum(-1) != 0
        causal = subsequent_mask(ts).to(self.device)  # (1, ts, ts)
        sequence_masks = (has_tokens.view(ds, 1, 1) & causal).byte().to(self.device)

        # knowledge mask: a knowledge position counts as padding when its
        # first feature is exactly 0; entry (i, j) is set only when both
        # positions i and j are real.
        k_masks = (knowledge_inputs[:, :, 0] != 0.).long()
        knowledge_masks = (k_masks.unsqueeze(1) & k_masks.unsqueeze(2)).byte().to(self.device)

        # padding mask: the same outer AND, built from the context mask.
        token_masks = (context_masks.unsqueeze(1) & context_masks.unsqueeze(2)).byte().to(self.device)

        # Add positional encoding
        context_inputs = self.add_pe_c(context_inputs)
        knowledge_inputs = self.add_pe_k(knowledge_inputs)

        # Knowledge encoding. sequence_masks and k_masks are forwarded for
        # interface compatibility; EncoderLayer as written in this file does
        # not read them.
        context_hidden, knowledge_hidden = self.klayer(
            context_inputs, knowledge_inputs,
            token_masks, sequence_masks, knowledge_masks, k_masks,
        )
        return context_hidden, knowledge_hidden