import torch
import math
import torch.nn as nn
import numpy as np

# Lookup table from an activation name to a ready-made activation module.
# The instances are created once, when this module is imported, so every
# layer that selects the same name receives the very same module object.
ACT2FN = dict(
    gelu=nn.GELU(),
    relu=nn.ReLU(),
    elu=nn.ELU(),
)

class AttentionFusion(nn.Module):
    """Combine a layer's inputs ``I`` with an attention output ``A``.

    Supported values of ``mode``:

    * ``'C'``: ``cat[I, A]`` -> dropout -> linear -> activation.
    * ``'F'``: ``cat[I, A, I - A, (I + A) / 2]`` -> dropout -> linear
      -> activation.
    * ``'A'``: ``LayerNorm((I + A) / 2)``.
    * ``'R'``: ``LayerNorm(I + A)``.

    NOTE(review): earlier documentation listed the fourth ``'F'`` component
    as ``I * A``, but the implementation concatenates the mean
    ``(I + A) / 2``.  The computation is deliberately left untouched here so
    numerical results do not change; confirm which variant is intended.
    """

    # Every mode that ``forward`` knows how to handle.
    _VALID_MODES = ('C', 'R', 'A', 'F')

    def __init__(self, hidden_size, activation="elu", mode="C", layer_norm_eps=1e-8, dropout=0.1):
        """Build the fusion layer.

        Args:
            hidden_size: size of the last dimension of the fused output.
            activation: key into ``ACT2FN`` ("gelu", "relu" or "elu").
            mode: one of ``'C'``, ``'R'``, ``'A'``, ``'F'`` (see class docs).
            layer_norm_eps: ``eps`` given to the ``LayerNorm`` used by modes
                ``'A'`` and ``'R'``.
            dropout: dropout probability applied to the concatenated
                features in modes ``'C'`` and ``'F'``.

        Raises:
            ValueError: if ``mode`` is not a supported mode.
            KeyError: if ``activation`` is not a key of ``ACT2FN``.
        """
        super(AttentionFusion, self).__init__()
        # Fail at construction time instead of on the first forward pass.
        if mode not in self._VALID_MODES:
            raise ValueError(
                "Unknown attention fusion mode %r; expected one of %s"
                % (mode, ", ".join(self._VALID_MODES)))
        self.mode = mode
        # Only the concatenating modes need a projection back to hidden_size.
        if self.mode == 'C':
            self.output = nn.Linear(hidden_size * 2, hidden_size)
        elif self.mode == 'F':
            self.output = nn.Linear(hidden_size * 4, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.activation = ACT2FN[activation]
        # Created for every mode (although only 'A'/'R' use it) so the set of
        # registered sub-modules does not depend on the mode beyond `output`.
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, inputs, attention_outputs):
        """Fuse ``inputs`` with ``attention_outputs`` according to ``self.mode``.

        Args:
            inputs: tensor ``I``.
            attention_outputs: tensor ``A``; it is concatenated with ``I``
                along the last dimension (modes 'C'/'F') or combined
                element-wise with it (modes 'A'/'R').

        Returns:
            The fused tensor.

        Raises:
            AssertionError: if ``self.mode`` was changed to an unsupported
                value after construction.
        """
        if self.mode == 'C':
            feature = torch.cat([inputs, attention_outputs], -1)
            return self.activation(self.output(self.dropout(feature)))
        if self.mode == 'F':
            feature = torch.cat(
                [
                    inputs,
                    attention_outputs,
                    inputs - attention_outputs,
                    (inputs + attention_outputs) / 2,
                ],
                -1,
            )
            return self.activation(self.output(self.dropout(feature)))
        if self.mode == 'A':
            return self.layernorm((inputs + attention_outputs) / 2)
        if self.mode == 'R':
            return self.layernorm(inputs + attention_outputs)
        raise AssertionError("Please select a right mode for attention fusion.")

class MHAttention(nn.Module):
    """Multi-head scaled dot-product attention followed by an AttentionFusion step.

    ``query``, ``key`` and ``value`` are projected to ``attention_size``
    features, split into ``num_attention_heads`` heads, attended, merged back
    and finally fused with the original ``query`` tensor.

    NOTE(review): the fusion layer is built with ``hidden_size`` while the
    merged context has ``attention_size`` features, so this appears to
    require ``attention_size == hidden_size`` -- confirm against callers.
    """

    def __init__(self, hidden_size, attention_size, activation, fusion_mode, num_attention_heads, layer_norm_eps, dropout, attention_probs_dropout_prob, output_attentions=False):
        """Build the projections and the fusion layer.

        Args:
            hidden_size: feature size of the query/key/value inputs.
            attention_size: total feature size across all heads.
            activation: activation name forwarded to ``AttentionFusion``.
            fusion_mode: fusion mode forwarded to ``AttentionFusion``.
            num_attention_heads: number of heads; must divide
                ``attention_size``.
            layer_norm_eps: forwarded to ``AttentionFusion``.
            dropout: dropout probability forwarded to ``AttentionFusion``.
            attention_probs_dropout_prob: probability of ``self.dropout``
                (constructed, but currently not applied in ``forward``).
            output_attentions: when true, ``forward`` also returns the
                attention scores.

        Raises:
            ValueError: if ``attention_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super(MHAttention, self).__init__()
        if attention_size % num_attention_heads != 0:
            # The value being checked is the attention size, so name it as such.
            raise ValueError(
                "The attention size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (attention_size, num_attention_heads))
        self.output_attentions = output_attentions

        self.num_attention_heads = num_attention_heads
        # Exact integer division; divisibility was verified above.
        self.attention_head_size = attention_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.attfusion = AttentionFusion(
            hidden_size=hidden_size,
            activation=activation,
            mode=fusion_mode,
            layer_norm_eps=layer_norm_eps,
            dropout=dropout)
        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        A ``(batch, seq, all_head_size)`` tensor becomes
        ``(batch, heads, seq, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, attention_mask=None):
        """Attend from ``query`` to ``key``/``value`` and fuse with ``query``.

        Args:
            query, key, value: 3-D tensors whose last dimension is
                ``hidden_size``.
            attention_mask: optional tensor added to the scaled scores before
                the softmax (presumably large negative values at masked
                positions -- confirm with the caller that builds it).

        Returns:
            ``(att_enc,)`` or, when ``output_attentions`` is set,
            ``(att_enc, attention_scores)``.

            NOTE(review): the second element holds the scaled (and masked)
            pre-softmax scores, not the softmax probabilities -- confirm this
            is what consumers expect.
        """
        query_layer = self.transpose_for_scores(self.query(query))
        key_layer = self.transpose_for_scores(self.key(key))
        value_layer = self.transpose_for_scores(self.value(value))

        # Raw attention scores: dot product between "query" and "key",
        # scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is precomputed for all layers in DTModel forward().
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.  The functional
        # form avoids constructing a new nn.Softmax module on every call.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Dropout on the attention probabilities is currently disabled;
        # self.dropout is constructed but not applied here.

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        att_enc = self.attfusion(query, context_layer)
        if self.output_attentions:
            return (att_enc, attention_scores)
        return (att_enc,)


class DotAttention(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors.

    Scores are ``q @ k^T / temperature``; a softmax over the last (key) axis
    turns them into weights, dropout is applied to the weights, and the
    result is the weighted sum of ``v``.
    """

    def __init__(self, temperature, dim, attn_dropout=0.1):
        """Store the scaling factor and build the softmax/dropout modules.

        Args:
            temperature: divisor applied to the raw dot products.
            dim: accepted but not used by this layer.
            attn_dropout: dropout probability for the attention weights.
        """
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)``.

        Args:
            q, k, v: 3-D tensors suitable for ``torch.bmm``.
            mask: optional mask; positions where it is set are filled with
                ``-inf`` before the softmax.  A row that is masked everywhere
                therefore yields NaN weights.
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)
        # The returned weights are the post-dropout ones.
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights


class wDotAttention(nn.Module):
    """Scaled dot-product attention with a learned bilinear weight.

    Scores are ``(q @ W) @ k^T / temperature`` where ``W`` is a trainable
    ``dim x dim`` matrix; the remaining steps (softmax over the last axis,
    dropout, weighted sum of ``v``) follow plain dot-product attention.
    """

    def __init__(self, temperature, dim, attn_dropout=0.1):
        """Create the bilinear weight and the softmax/dropout modules.

        Args:
            temperature: divisor applied to the raw scores.
            dim: feature size of ``q``/``k``; ``W`` has shape ``(dim, dim)``.
            attn_dropout: dropout probability for the attention weights.
        """
        super().__init__()
        self.temperature = temperature
        # W starts out uniformly distributed in [-0.05, 0.05).
        initial_w = torch.Tensor(dim, dim).uniform_(-0.05, 0.05)
        self.W = nn.parameter.Parameter(initial_w)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)``.

        Args:
            q, k, v: 3-D tensors suitable for ``torch.bmm``.
            mask: optional mask; positions where it is set are filled with
                ``-inf`` before the softmax.
        """
        projected_q = torch.matmul(q, self.W)
        scores = torch.bmm(projected_q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)
        # The returned weights are the post-dropout ones.
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights



class wDotCoAttention(nn.Module):
    """Bidirectional (co-)attention between two sequences using wDotAttention.

    One ``wDotAttention`` layer and one ``AttentionFusion`` layer are shared
    by both directions (s -> t and t -> s).
    """

    def __init__(self, hidden_size, activation='elu', fusion_mode='C', layer_norm_eps=1e-8, dropout=0.1):
        """Build the shared attention and fusion layers.

        Args:
            hidden_size: feature size of both inputs; its square root is used
                as the attention temperature.
            activation: activation name forwarded to ``AttentionFusion``.
            fusion_mode: fusion mode forwarded to ``AttentionFusion``.
            layer_norm_eps: forwarded to ``AttentionFusion``.
            dropout: dropout probability used by both the attention layer and
                the fusion layer.
        """
        super().__init__()
        self.temperature = np.power(hidden_size, 0.5)
        self.att_layer = wDotAttention(self.temperature, hidden_size, dropout)
        fusion_kwargs = {
            "hidden_size": hidden_size,
            "activation": activation,
            "mode": fusion_mode,
            "layer_norm_eps": layer_norm_eps,
            "dropout": dropout,
        }
        self.fusion = AttentionFusion(**fusion_kwargs)

    def forward(self, input_s, input_t, mask_matrix_s=None, mask_matrix_t=None):
        """Attend in both directions and fuse each result with its own input.

        Args:
            input_s, input_t: the two sequences.
            mask_matrix_s: optional mask for the attention from ``input_s``
                over ``input_t``.
            mask_matrix_t: optional mask for the attention from ``input_t``
                over ``input_s``.

        Returns:
            ``(output_s, output_t, attention_probs_s, attention_probs_t)``.
        """
        # s attends over t, then t attends over s.
        attended_s, attention_probs_s = self.att_layer(input_s, input_t, input_t, mask_matrix_s)
        attended_t, attention_probs_t = self.att_layer(input_t, input_s, input_s, mask_matrix_t)
        # Fuse each attended representation with the sequence it came from.
        fused_s = self.fusion(input_s, attended_s)
        fused_t = self.fusion(input_t, attended_t)
        return fused_s, fused_t, attention_probs_s, attention_probs_t


class DotCoAttention(nn.Module):
    """Bidirectional (co-)attention between two sequences using DotAttention.

    One ``DotAttention`` layer and one ``AttentionFusion`` layer are shared
    by both directions (s -> t and t -> s).
    """

    def __init__(self, hidden_size, activation='elu', fusion_mode='C', layer_norm_eps=1e-8, dropout=0.1):
        """Build the shared attention and fusion layers.

        Args:
            hidden_size: feature size of both inputs; its square root is used
                as the attention temperature.
            activation: activation name forwarded to ``AttentionFusion``.
            fusion_mode: fusion mode forwarded to ``AttentionFusion``.
            layer_norm_eps: forwarded to ``AttentionFusion``.
            dropout: dropout probability used by both the attention layer and
                the fusion layer.
        """
        super().__init__()
        self.temperature = np.power(hidden_size, 0.5)
        self.att_layer = DotAttention(self.temperature, hidden_size, dropout)
        fusion_kwargs = {
            "hidden_size": hidden_size,
            "activation": activation,
            "mode": fusion_mode,
            "layer_norm_eps": layer_norm_eps,
            "dropout": dropout,
        }
        self.fusion = AttentionFusion(**fusion_kwargs)

    def forward(self, input_s, input_t, mask_matrix_s=None, mask_matrix_t=None):
        """Attend in both directions and fuse each result with its own input.

        Args:
            input_s, input_t: the two sequences.
            mask_matrix_s: optional mask for the attention from ``input_s``
                over ``input_t``.
            mask_matrix_t: optional mask for the attention from ``input_t``
                over ``input_s``.

        Returns:
            ``(output_s, output_t, attention_probs_s, attention_probs_t)``.
        """
        # s attends over t, then t attends over s.
        attended_s, attention_probs_s = self.att_layer(input_s, input_t, input_t, mask_matrix_s)
        attended_t, attention_probs_t = self.att_layer(input_t, input_s, input_s, mask_matrix_t)
        # Fuse each attended representation with the sequence it came from.
        fused_s = self.fusion(input_s, attended_s)
        fused_t = self.fusion(input_t, attended_t)
        return fused_s, fused_t, attention_probs_s, attention_probs_t