#%%
import torch
import numpy as np
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from .modules import BiLSTMEncoder
from .modules import MemAttention
# Token id treated as padding in word-id tensors (used by DFITN.get_attn_key_pad_mask).
PAD = 0
# Activation name -> activation module. The module instances are created once
# here and shared by every caller that looks them up.
ACT2FN = dict(
    gelu=nn.GELU(),
    relu=nn.ReLU(),
    elu=nn.ELU(),
)
#%%

class DFITNEmbedder(nn.Module):
    """Embed a (source, target) pair of token-id sequences.

    Each token is represented by a frozen pre-trained word vector. That vector is
    optionally concatenated with a character-CNN feature, optionally projected by
    a linear reducer, and optionally extended with a one-dimensional exact-match
    feature. Finally it is summed with a position encoding. Positions whose word
    id is 0 are zeroed in the output.

    Args:
        data: dataset object. Must expose ``TEXT.vocab`` (with ``vectors``) and,
            when ``use_char`` is set, ``char_vocab`` and ``max_word_len``.
        word_size: dimension of the pre-trained word vectors.
        char_size: dimension of the character embeddings.
        char_hidden_size: number of character-CNN output channels.
        hidden_size: final embedding size when ``use_reduction`` is set. The
            reducer outputs ``hidden_size - use_em`` features so that the
            exact-match slot brings the total back to ``hidden_size``.
        use_char, use_em, use_reduction: feature toggles (used as 0/1).

    Attributes:
        embedding_size: last dimension of the embeddings returned by ``forward``.
    """

    def __init__(self, data, word_size, char_size, char_hidden_size, hidden_size, use_char, use_em, use_reduction):
        super(DFITNEmbedder, self).__init__()
        self.data = data
        self.use_em = use_em
        self.use_char = use_char
        self.use_reduction = use_reduction

        # ------ Word Embedding (pre-trained, frozen) ------
        self.word_emb = nn.Embedding(len(self.data.TEXT.vocab), word_size)
        self.word_emb.weight.data.copy_(self.data.TEXT.vocab.vectors)
        self.word_emb.weight.requires_grad = False
        self.embedding_size = word_size

        # ----- Character Embedding ------
        if use_char:
            self.max_word_len = self.data.max_word_len
            self.char_hidden_size = char_hidden_size
            self.char_emb = nn.Embedding(len(self.data.char_vocab), char_size, padding_idx=0)
            self.char_conv = nn.Conv1d(char_size, char_hidden_size, 3)
            # A width-3 unpadded conv leaves max_word_len - 2 positions; pool over all of them.
            self.char_max = nn.MaxPool1d(self.max_word_len - 3 + 1)
            self.embedding_size += char_hidden_size
        # ------ Embedding Reduction ------
        if use_reduction:
            # Reserve one output slot for the exact-match feature appended later.
            self.reducer = nn.Linear(self.embedding_size, hidden_size - use_em*1)
            self.embedding_size = hidden_size
        # ------ Exact Match Embedding ------
        if use_em:
            # With reduction the slot is already counted in hidden_size.
            self.embedding_size = self.embedding_size + (1 - use_reduction) * 1
        # ----- Position Embedding ------
        # NOTE(review): AbsolutePositionEncoding is neither defined nor imported in
        # this file; confirm where it is supposed to come from (e.g. .modules).
        self.pos_emb = AbsolutePositionEncoding(self.embedding_size)
        self.dropout = nn.Dropout(0.2)

    def _encode_chars(self, chars):
        """Return character-CNN features for a ``(batch, seq_len, max_word_len)`` id tensor.

        The result has shape ``(batch, seq_len, char_hidden_size)``.
        """
        seq_len = chars.size(1)
        flat = chars.view(-1, self.max_word_len)
        conv = self.char_conv(self.char_emb(flat).permute(0, 2, 1))
        # Pooling spans the whole convolved width, so the last dim is 1. Squeeze
        # only that dim instead of every size-1 dim.
        pooled = self.char_max(conv).squeeze(-1)
        return pooled.view(-1, seq_len, self.char_hidden_size)

    def forward(self, word_s, word_t, char_s=None, char_t=None, em_s=None, em_t=None):
        """Embed both sequences.

        Args:
            word_s, word_t: word-id tensors, presumably ``(batch, seq_len)``.
                Id 0 is treated as padding.
            char_s, char_t: character-id tensors ``(batch, seq_len, max_word_len)``.
                Required when ``use_char`` is set.
            em_s, em_t: per-token exact-match features, presumably
                ``(batch, seq_len)``. Required when ``use_em`` is set.

        Returns:
            ``(emb_s, emb_t, mask_s, mask_t)``. The embeddings have last dimension
            ``embedding_size`` and are zero at padding positions. The masks are
            long tensors holding 1 for non-padding tokens and 0 for padding.

        Raises:
            ValueError: if a feature enabled at construction is missing its input.
        """
        if self.use_char and (char_s is None or char_t is None):
            raise ValueError("char_s and char_t are required when use_char is enabled")
        if self.use_em and (em_s is None or em_t is None):
            raise ValueError("em_s and em_t are required when use_em is enabled")

        emb_s = self.word_emb(word_s)
        emb_t = self.word_emb(word_t)
        mask_s = word_s.ne(0).to(dtype=torch.long)
        mask_t = word_t.ne(0).to(dtype=torch.long)

        if self.use_char:
            emb_s = torch.cat([emb_s, self._encode_chars(char_s)], dim=-1)
            emb_t = torch.cat([emb_t, self._encode_chars(char_t)], dim=-1)

        emb_s = self.dropout(emb_s)
        emb_t = self.dropout(emb_t)
        if self.use_reduction:
            emb_s = self.reducer(emb_s)
            emb_t = self.reducer(emb_t)

        if self.use_em:
            # Appended after reduction and dropout, so the feature is kept verbatim.
            emb_s = torch.cat([emb_s, em_s.unsqueeze(-1)], dim=-1)
            emb_t = torch.cat([emb_t, em_t.unsqueeze(-1)], dim=-1)

        bsz = emb_s.size(0)
        emb_s = emb_s + self.pos_emb(emb_s).expand(bsz, -1, -1)
        emb_t = emb_t + self.pos_emb(emb_t).expand(bsz, -1, -1)

        # Zero out padding positions, including their position encoding.
        emb_s = emb_s * mask_s.unsqueeze(-1)
        emb_t = emb_t * mask_t.unsqueeze(-1)
        return emb_s, emb_t, mask_s, mask_t



class DFITNLayer(nn.Module):
    """One DFITN interaction layer.

    The layer encodes both sequences and builds a pairwise interaction tensor
    from elementwise products of every source/target position pair. When a
    previous layer's interaction tensor is given, it is concatenated on. The
    result is projected to ``global_mat_size`` and passed to ``MemAttention``.
    The attended states are optionally combined with the previous layer's
    states (``conn_mode``) and concatenated with the original embeddings.

    Args:
        input_size: feature size of ``input_s``/``input_t``.
        hidden_size: hidden size passed to ``BiLSTMEncoder``. The attention is
            configured to output ``2*hidden_size`` features.
        current_mat_size: last dimension of the interaction tensor fed to
            ``att_updater``.
        global_mat_size: output size of ``att_updater``.
        attention_input_size: ``hidden_size`` argument of ``MemAttention``.
        activation: key into ``ACT2FN``.
        encode_mode: sequence encoder; only ``'LSTM'`` is implemented.
        att_mode, fusion_mode, layer_norm_eps: forwarded to ``MemAttention``.
        conn_mode: ``'OA'`` averages with the previous states, ``'OR'`` adds
            them, and ``'O'`` (or any other value) keeps only the new states.
        dropout: dropout probability.

    Raises:
        ValueError: if ``encode_mode`` is not ``'LSTM'``.
    """

    def __init__(self, input_size, hidden_size, current_mat_size, global_mat_size, attention_input_size, activation, encode_mode='LSTM', att_mode='M', conn_mode='OR', fusion_mode='C', layer_norm_eps=1e-8, dropout=0.2):
        super(DFITNLayer, self).__init__()
        if encode_mode != 'LSTM':
            # forward() has no other encoder. Fail here instead of with an
            # unbound-variable error on the first batch.
            raise ValueError(f"Unsupported encode_mode {encode_mode!r}; only 'LSTM' is implemented")
        self.encode_mode = encode_mode
        self.conn_mode = conn_mode
        self.encoder = BiLSTMEncoder(input_size, hidden_size)
        # Projects the (possibly history-extended) interaction tensor to global_mat_size.
        self.att_updater = nn.Linear(current_mat_size, global_mat_size)
        self.attention = MemAttention(
            hidden_size=attention_input_size,  # callers pass 2*hidden_size
            output_size=2*hidden_size,
            attention_size=global_mat_size,
            activation=activation,
            att_mode=att_mode,
            fusion_mode=fusion_mode,
            layer_norm_eps=layer_norm_eps,
            dropout=dropout
        )
        self.activation = ACT2FN[activation]
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_s, input_t, mask_s, mask_t, len_s, len_t, original_s, original_t, mask_matrix_s, mask_matrix_t, history_s=None, history_t=None, history_matrix=None):
        """Run the layer on a (source, target) pair.

        ``mask_s`` and ``mask_t`` are accepted for interface compatibility and are
        not used here. ``len_*`` is passed to the encoder, and ``mask_matrix_*``
        is passed to the attention.

        Returns:
            A tuple ``(output_s, output_t, current_matrix, hidden_s, hidden_t,
            attention_probs_s, attention_probs_t)``:

            - ``output_*``: ``[hidden_*, original_*]`` concatenated on the last
              dim, after dropout.
            - ``current_matrix``: the projected interaction tensor, meant to be
              passed to the next layer as ``history_matrix``.
            - ``hidden_*``: the states meant to be passed on as ``history_*``.

        Raises:
            ValueError: if ``encode_mode`` is not ``'LSTM'``.
        """
        if self.encode_mode != "LSTM":
            raise ValueError(f"Unsupported encode_mode {self.encode_mode!r}; only 'LSTM' is implemented")
        hidden_s = self.encoder(input_s, len_s)
        hidden_t = self.encoder(input_t, len_t)

        # Pairwise products of every source/target position. Assuming the encoder
        # returns (batch, seq_len, dim), this is (batch, len_s, len_t, dim).
        current_matrix = hidden_s.unsqueeze(2) * hidden_t.unsqueeze(1)
        if history_matrix is not None:
            current_matrix = torch.cat([current_matrix, history_matrix], -1)
        current_matrix = self.dropout(current_matrix)
        current_matrix = self.activation(self.att_updater(current_matrix))

        hidden_s, hidden_t, attention_probs_s, attention_probs_t = self.attention(
            hidden_s, hidden_t, current_matrix, mask_matrix_s, mask_matrix_t)

        if history_s is not None:
            if self.conn_mode == 'OA':
                hidden_s = (history_s + hidden_s) / 2
                hidden_t = (history_t + hidden_t) / 2
            elif self.conn_mode == 'OR':
                hidden_s = history_s + hidden_s
                hidden_t = history_t + hidden_t
            # 'O' and any other mode: no connection to the previous layer.

        output_s = self.dropout(torch.cat([hidden_s, original_s], -1))
        output_t = self.dropout(torch.cat([hidden_t, original_t], -1))
        return output_s, output_t, current_matrix, hidden_s, hidden_t, attention_probs_s, attention_probs_t


class DFITNEncoder(nn.Module):
    """Stack of ``num_layers`` DFITNLayer modules applied in sequence.

    Every layer receives the embedder output as ``original_s``/``original_t``.
    Each layer after the first also receives the following from the previous
    layer:

    - its output, as the layer input;
    - its attended states, as ``history_s``/``history_t``;
    - its projected interaction tensor, as ``history_matrix``.

    Args:
        input_size: feature size of the embedder output.
        hidden_size: hidden size passed to every layer. Each layer's attention
            is configured to emit ``2*hidden_size`` features.
        global_mat_size: size of each layer's projected interaction tensor.
        num_layers: number of stacked layers.
        activation, encode_mode, att_mode, fusion_mode, conn_mode,
        layer_norm_eps, dropout: forwarded to every DFITNLayer.

    Raises:
        ValueError: if ``encode_mode`` is not ``'LSTM'``.
    """

    def __init__(self, input_size, hidden_size, global_mat_size, num_layers, activation, encode_mode, att_mode, fusion_mode, conn_mode, layer_norm_eps, dropout=0.2):
        super(DFITNEncoder, self).__init__()
        if encode_mode != 'LSTM':
            # Otherwise encoder_list is never created and forward() fails with
            # an AttributeError.
            raise ValueError(f"Unsupported encode_mode {encode_mode!r}; only 'LSTM' is implemented")
        self.encoder_list = nn.ModuleList([
            DFITNLayer(
                # Later layers see [attended states (2*hidden_size), embeddings (input_size)].
                input_size=input_size if i == 0 else input_size + 2*hidden_size,
                hidden_size=hidden_size,
                # Later layers concatenate the previous projected interaction
                # tensor (global_mat_size) onto the new one (2*hidden_size).
                current_mat_size=2*hidden_size if i == 0 else global_mat_size + 2*hidden_size,
                global_mat_size=global_mat_size,
                attention_input_size=2*hidden_size,
                activation=activation,
                encode_mode=encode_mode,
                conn_mode=conn_mode,
                att_mode=att_mode,
                fusion_mode=fusion_mode,
                layer_norm_eps=layer_norm_eps,
                dropout=dropout) for i in range(num_layers)
        ])

    def forward(self, input_s, input_t, mask_s, mask_t, len_s, len_t, mask_matrix_s, mask_matrix_t):
        """Run every layer in order.

        Returns:
            ``(output_s, output_t, attention_probs_list_s, attention_probs_list_t)``.
            The outputs are the last layer's ``output_*``, or the inputs
            themselves when ``encoder_list`` is empty. The lists hold one
            attention-probability entry per layer.
        """
        original_s, original_t = input_s, input_t
        hidden_s = hidden_t = current_matrix = None
        attention_probs_list_s, attention_probs_list_t = [], []
        for layer in self.encoder_list:
            (input_s, input_t, current_matrix, hidden_s, hidden_t,
             attention_probs_s, attention_probs_t) = layer(
                input_s=input_s,
                input_t=input_t,
                mask_s=mask_s,
                mask_t=mask_t,
                len_s=len_s,
                len_t=len_t,
                original_s=original_s,
                original_t=original_t,
                mask_matrix_s=mask_matrix_s,
                mask_matrix_t=mask_matrix_t,
                history_s=hidden_s,
                history_t=hidden_t,
                history_matrix=current_matrix
            )
            attention_probs_list_s.append(attention_probs_s)
            attention_probs_list_t.append(attention_probs_t)
        return input_s, input_t, attention_probs_list_s, attention_probs_list_t


class DFITN(nn.Module):
    """Sentence-pair classifier: embedder -> DFITNEncoder -> pooling -> MLP.

    Args:
        embed: embedding module. It is called as
            ``embed(s, t, char_s, char_t, em_s, em_t)`` and must expose
            ``embedding_size``.
        hidden_size, global_mat_size, num_layers, activation, encode_mode,
        att_mode, conn_mode, fusion_mode, layer_norm_eps: forwarded to
            DFITNEncoder.
        dropout: forwarded to DFITNEncoder and also used in the classifier.
        classifier_hidden_size: width of the classifier's hidden layer.
        pool_mode: sequence pooling. Only ``'M'`` (max over positions) is
            implemented.
        num_labels: number of logits produced per pair.

    Raises:
        ValueError: if ``pool_mode`` is not supported.
    """

    def __init__(self, embed, hidden_size, global_mat_size, classifier_hidden_size, dropout, num_layers, activation, encode_mode, att_mode, conn_mode, pool_mode, fusion_mode, layer_norm_eps, num_labels):
        super(DFITN, self).__init__()
        if pool_mode != 'M':
            # forward() only knows max pooling. Any other value used to surface
            # as an UnboundLocalError on the first batch.
            raise ValueError(f"Unsupported pool_mode {pool_mode!r}; only 'M' (max pooling) is implemented")
        self.embedding = embed
        self.num_labels = num_labels
        input_size = embed.embedding_size
        self.encoder = DFITNEncoder(
            input_size, hidden_size, global_mat_size, num_layers, activation, encode_mode, att_mode, fusion_mode, conn_mode, layer_norm_eps, dropout)
        self.pool_mode = pool_mode
        # Each encoder position carries input_size + 2*hidden_size features
        # (with num_layers >= 1). forward() concatenates 4 pooled vectors of that
        # size.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(4*(input_size + 2*hidden_size), classifier_hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout),
            nn.Linear(classifier_hidden_size, num_labels)
        )

    @staticmethod
    def get_attn_key_pad_mask(seq_q, seq_k):
        """Return a bool mask that is True where the key position in ``seq_k`` is PAD.

        The mask is broadcast over the query length of ``seq_q``, giving shape
        ``(batch, len_q, len_k)`` for 2-D id tensors.
        """
        len_q = seq_q.size(1)
        padding_mask = seq_k.eq(PAD)
        padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk
        return padding_mask

    def forward(self, input_s, input_t, len_s, len_t, char_s=None, char_t=None, em_s=None, em_t=None, labels=None):
        """Compute classification logits for a batch of (source, target) pairs.

        ``labels`` is accepted for interface compatibility but is not used; no
        loss is computed here.

        Returns:
            ``(logits, attention_probs_list_s, attention_probs_list_t)``.

        Raises:
            ValueError: if ``pool_mode`` is not ``'M'``.
        """
        if self.pool_mode != 'M':
            raise ValueError(f"Unsupported pool_mode {self.pool_mode!r}; only 'M' (max pooling) is implemented")
        # The padding masks are built from the raw token ids, before embedding.
        mask_matrix_s = self.get_attn_key_pad_mask(input_s, input_t)
        mask_matrix_t = self.get_attn_key_pad_mask(input_t, input_s)
        emb_s, emb_t, mask_s, mask_t = self.embedding(input_s, input_t, char_s, char_t, em_s, em_t)
        hidden_s, hidden_t, attention_probs_list_s, attention_probs_list_t = self.encoder(
            emb_s, emb_t, mask_s, mask_t, len_s, len_t, mask_matrix_s, mask_matrix_t)

        # Max pooling over dim 1; padding positions take part in the max.
        feature_s = hidden_s.max(1)[0]
        feature_t = hidden_t.max(1)[0]
        feature = torch.cat([feature_s, feature_t, feature_s - feature_t, (feature_s + feature_t) / 2], -1)
        logits = self.classifier(feature)
        return logits, attention_probs_list_s, attention_probs_list_t

# %%
