import torch
import torch.nn as nn
import numpy as np

# Lookup table from activation name to a shared activation-module instance.
# These activations are parameter-free, so one instance can be reused by
# every module that looks it up.
ACT2FN = dict(
    gelu=nn.GELU(),
    relu=nn.ReLU(),
    elu=nn.ELU(),
)




class AttentionFusion(nn.Module):
    def __init__(self, hidden_size, output_size, activation="elu", mode="C", layer_norm_eps=1e-8, dropout=0.1):
        '''
        Fuse an input representation I with its attention output A.

        hidden_size: last-dimension size of both I and A
        output_size: last-dimension size produced by the projection (modes C, F)
        activation: key into ACT2FN; applied after the projection (modes C, F)
        mode: one of [C, R, A, F]
            ->  C: act(W cat[I, A]);          R: LayerNorm(I + A);
                A: LayerNorm((I + A) / 2);    F: act(W cat[I, A, I-A, I*A])
            Modes R and A do not project, so their output keeps hidden_size.
        layer_norm_eps: epsilon of the LayerNorm used by modes R and A
        dropout: dropout probability on the concatenated features (modes C, F)
        '''
        super(AttentionFusion, self).__init__()
        self.mode = mode
        if self.mode == 'C':
            self.output = nn.Linear(hidden_size*2, output_size)
        elif self.mode == 'F':
            self.output = nn.Linear(hidden_size*4, output_size)
        self.dropout = nn.Dropout(dropout)
        self.activation = ACT2FN[activation]
        # Modes R and A normalize the un-projected combination of I and A, whose
        # last dimension is hidden_size. Sizing the norm by output_size there
        # crashed whenever output_size != hidden_size. Modes C and F never use
        # the norm, so they keep output_size, which leaves their existing
        # checkpoints loadable.
        norm_size = hidden_size if self.mode in ('A', 'R') else output_size
        self.layernorm = nn.LayerNorm(norm_size, eps=layer_norm_eps)

    def forward(self, inputs, attention_outputs):
        '''
        inputs: I, last dimension hidden_size
        attention_outputs: A, same shape as inputs (required by the element-wise
            operations below)

        Returns the fused representation described in __init__.
        Raises AssertionError for an unknown mode.
        '''
        if self.mode == 'C':
            feature = torch.cat([inputs, attention_outputs], -1)
            outputs = self.activation(self.output(self.dropout(feature)))
        elif self.mode == 'F':
            # Use the element-wise product, as documented. The previous
            # (I + A) / 2 term was, like I - A, a linear combination of I and A,
            # so the Linear layer could learn nothing from it beyond mode C.
            feature = torch.cat(
                [inputs, attention_outputs, inputs - attention_outputs, inputs * attention_outputs],
                -1,
            )
            outputs = self.activation(self.output(self.dropout(feature)))
        elif self.mode == 'A':
            outputs = (inputs + attention_outputs) / 2
            outputs = self.layernorm(outputs)
        elif self.mode == 'R':
            outputs = inputs + attention_outputs
            outputs = self.layernorm(outputs)
        else:
            raise AssertionError("Please select a right mode for attention fusion.")
        return outputs

class Attention(nn.Module):
    """Softmax attention over a precomputed score matrix.

    Normalizes the scores along dim 2 and applies the resulting weights to
    ``v`` with ``torch.bmm``. Returns the attended output and the
    attention probabilities.
    """

    def __init__(self, attn_dropout=0.1):
        """
        attn_dropout: dropout probability applied to the attention weights
            before they multiply ``v`` (active only in training mode).
        """
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, attn, v, mask=None):
        """
        attn: raw scores; the left operand of torch.bmm, so 3-D (batch, n, m).
        v: values; the right operand of torch.bmm, so 3-D (batch, m, d).
        mask: optional boolean tensor broadcastable to ``attn``; True marks
            positions that must receive zero weight.

        Returns (output, probs). ``output`` is bmm(dropout(probs), v).
        ``probs`` is the softmax over dim 2, with masked positions set to 0.
        A row whose positions are all masked yields all-zero probabilities
        and a zero output row rather than NaN.
        """
        if mask is not None:
            # Fill with the most negative finite value instead of -inf. A fully
            # masked row would otherwise give softmax(-inf, ..., -inf) = NaN,
            # which poisons both the output and the gradients.
            attn = attn.masked_fill(mask, torch.finfo(attn.dtype).min)
        attn = self.softmax(attn)
        if mask is not None:
            # In partially masked rows the masked weights are already 0 (exp
            # underflows). This also zeroes fully masked rows, which softmax
            # made uniform.
            attn = attn.masked_fill(mask, 0.0)
        # Apply the configured dropout to the weights. It was previously
        # constructed but never used. The returned probabilities are the
        # un-dropped ones.
        output = torch.bmm(self.dropout(attn), v)
        return output, attn

class MemAttention(nn.Module):
    """Two-way attention between sequences s and t driven by a score tensor.

    ``current_matrix`` is reduced over its last dimension to a score matrix.
    The reduction is a max-pool, an average-pool, or a dropout + linear +
    activation stack, depending on ``att_mode``. The score matrix attends
    from s over ``input_t``, and its transpose attends from t over
    ``input_s``. Each attended result is then combined with its own input by
    a single AttentionFusion module that both directions share.
    """

    def __init__(self, hidden_size, output_size, attention_size, activation='elu', att_mode='M', fusion_mode='C', layer_norm_eps=1e-8, dropout=0.1):
        '''
        hidden_size, output_size: forwarded to AttentionFusion
        attention_size: pooling window / linear input width along the last
            dimension of current_matrix
        activation: key into ACT2FN (linear pooling and fusion)
        att_mode: 'M' max-pool, 'A' average-pool, anything else a learned
            linear reduction
        fusion_mode: AttentionFusion mode
        layer_norm_eps, dropout: forwarded to the pooling / fusion modules
        '''
        super(MemAttention, self).__init__()
        self.att_mode = att_mode
        self.attention = Attention()

        # Submodules are created in the order attention, pooling, fusion, so
        # parameter order and RNG-driven initialization stay the same.
        window = (1, attention_size)
        if att_mode not in ('M', 'A'):
            reducer = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(attention_size, 1),
                ACT2FN[activation],
            )
        elif att_mode == 'M':
            reducer = nn.MaxPool2d(window)
        else:
            reducer = nn.AvgPool2d(window)
        self.pooling = reducer

        fusion_kwargs = dict(
            hidden_size=hidden_size,
            output_size=output_size,
            activation=activation,
            mode=fusion_mode,
            layer_norm_eps=layer_norm_eps,
            dropout=dropout,
        )
        self.fusion = AttentionFusion(**fusion_kwargs)

    def forward(self, input_s, input_t, current_matrix, mask_matrix_s=None, mask_matrix_t=None):
        '''
        input_s, input_t: value tensors for each side (used as ``v`` in
            Attention's bmm, so presumably (batch, len, hidden) -- TODO confirm).
        current_matrix: score tensor whose last dimension is pooled away;
            assumed to reduce to (batch, len_s, len_t) -- verify with callers.
        mask_matrix_s, mask_matrix_t: optional masks for each direction.

        Returns (fused_s, fused_t, attention_probs_s, attention_probs_t).
        '''
        scores = self.pooling(current_matrix).squeeze(-1)
        attended_s, probs_s = self.attention(scores, input_t, mask_matrix_s)
        attended_t, probs_t = self.attention(scores.transpose(-2, -1), input_s, mask_matrix_t)
        # Tuple elements evaluate left to right: fusion of s runs before t.
        return (
            self.fusion(input_s, attended_s),
            self.fusion(input_t, attended_t),
            probs_s,
            probs_t,
        )