import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from entmax import sparsemax


class ScaledDotAttention(nn.Module):
    """Scaled dot-product attention with optional learned projections.

    Scores are ``Q @ K^T / scale``; they are normalised over the key axis
    with softmax (or ``sparsemax`` when ``with_sparse`` is set), passed
    through dropout and used to take a weighted sum of ``V``.

    Args:
        d_hidden: Feature size of the query/key/value inputs; also the
            in/out size of every linear layer owned by this module.
        with_sparse: Normalise the scores with ``sparsemax`` instead of
            softmax.
        dropout_p: Dropout probability, applied to the attention weights
            and after the optional output projection.
    """

    def __init__(self, d_hidden, with_sparse=False, dropout_p=0.3):
        super().__init__()
        self.with_sparse = with_sparse
        self.d_hidden = d_hidden

        self.proj_q = nn.Linear(d_hidden, d_hidden)
        self.proj_k = nn.Linear(d_hidden, d_hidden)
        self.proj_v = nn.Linear(d_hidden, d_hidden)

        self.dropout = nn.Dropout(p=dropout_p)
        self.out = nn.Sequential(
            nn.Linear(d_hidden, d_hidden),
            nn.Dropout(p=dropout_p)
        )

    def forward(self, Q, K, V, q_mask=None, k_mask=None, scale=-1, with_projection=True,
                with_output=False, return_norm_score=False):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: Queries, ``[..., q_len, d_hidden]``.
            K: Keys, ``[..., k_len, d_hidden]``.
            V: Values, ``[..., k_len, d_hidden]``.
            q_mask: Optional ``[bs, q_len]`` mask; output rows whose mask
                entry equals 0 are zeroed. Must be expandable to the
                leading dimensions of the output.
            k_mask: Optional ``[bs, k_len]`` mask; keys whose mask entry
                equals 0 receive a large negative score. After a singleton
                query axis is inserted it must broadcast against the
                score tensor.
            scale: Divisor for the raw scores. Any value ``<= 0`` selects
                ``sqrt(d_hidden)``.
            with_projection: Apply the learned q/k/v linear projections
                before computing the scores.
            with_output: Apply the output projection (linear + dropout) to
                the attended vectors.
            return_norm_score: Selects the second return value (see below).

        Returns:
            ``(attn_vec, a)`` if ``return_norm_score`` else ``(attn_vec, e)``,
            where ``a`` holds the attention weights *after* dropout and
            ``e`` the (key-masked) unnormalised scores.
        """
        if with_projection:
            Q = self.proj_q(Q)
            K = self.proj_k(K)
            V = self.proj_v(V)

        if scale <= 0:
            scale = math.sqrt(self.d_hidden)
        e = Q.matmul(K.transpose(-1, -2)) / scale

        if k_mask is not None:
            # The fill value has to fit the tensor that is actually filled.
            # The score dtype can differ from the input dtype (e.g. under
            # torch.autocast with with_projection=False), and -2**31 is not
            # representable in float16, so key the choice on ``e``.
            masked = -2**14 if e.dtype == torch.float16 else -2**31
            # k_mask: [bs, k_len] -> [bs, 1, k_len]
            e.masked_fill_(k_mask.unsqueeze(-2) == 0, masked)

        if self.with_sparse:
            a = sparsemax(e, dim=-1)
        else:
            a = F.softmax(e, dim=-1)

        a = self.dropout(a)
        attn_vec = a.matmul(V)

        if with_output:
            attn_vec = self.out(attn_vec)
        if q_mask is not None:
            # q_mask: [bs, q_len]; blank out every feature of masked queries.
            q_mask = q_mask.expand(attn_vec.shape[:-1])
            attn_vec.masked_fill_((q_mask == 0).unsqueeze(-1), 0.)

        if return_norm_score:
            return attn_vec, a
        return attn_vec, e


class SelfAttention(nn.Module):
    """Attention pooling: collapse a sequence into a single vector.

    Every position is scored by a learned linear layer, the scores are
    normalised over the sequence axis with softmax (or ``sparsemax`` when
    ``with_sparse`` is set), and the result is the weighted sum of the
    inputs.

    Args:
        d_hidden: Feature size of the input vectors.
        dropout_p: Dropout probability, applied to the attention weights
            and after the optional output projection.
        with_sparse: Normalise the scores with ``sparsemax`` instead of
            softmax.
        with_output: Add a linear + dropout output projection.
    """

    def __init__(self, d_hidden, dropout_p=0.3, with_sparse=False, with_output=False):
        super().__init__()
        self.with_sparse = with_sparse
        self.with_output = with_output

        self.proj = nn.Linear(d_hidden, 1)
        self.dropout = nn.Dropout(p=dropout_p)

        if with_output:
            self.out = nn.Sequential(
                nn.Linear(d_hidden, d_hidden),
                nn.Dropout(p=dropout_p)
            )

    def forward(self, x, x_mask=None, return_norm_score=False):
        """Pool ``x`` over its sequence axis (second to last).

        Args:
            x: Input of shape ``[..., len, d_hidden]``; any number of
                leading dimensions (including none) is accepted.
            x_mask: Optional mask broadcastable to ``[..., len]``;
                positions whose mask entry equals 0 receive a large
                negative score.
            return_norm_score: Selects the second return value (see below).

        Returns:
            ``(attn_vec, a)`` if ``return_norm_score`` else ``(attn_vec, e)``.
            ``attn_vec`` has shape ``[..., d_hidden]``, ``a`` holds the
            attention weights *after* dropout and ``e`` the (masked)
            unnormalised scores, both of shape ``[..., len]``.
        """
        e = self.proj(x).squeeze(-1)

        if x_mask is not None:
            # The fill value has to fit the tensor that is actually filled.
            # The score dtype can differ from the input dtype (e.g. under
            # torch.autocast), and -2**31 is not representable in float16,
            # so key the choice on ``e`` rather than on ``x``.
            masked = -2**14 if e.dtype == torch.float16 else -2**31
            e.masked_fill_(x_mask == 0, masked)

        if self.with_sparse:
            a = sparsemax(e, dim=-1)
        else:
            a = F.softmax(e, dim=-1)

        a = self.dropout(a)
        # [..., 1, len] @ [..., len, d] -> [..., 1, d] -> [..., d].
        # Addressing the axis from the end keeps this correct for inputs
        # with extra (or no) leading batch dimensions.
        attn_vec = a.unsqueeze(-2).matmul(x).squeeze(-2)

        if self.with_output:
            attn_vec = self.out(attn_vec)

        if return_norm_score:
            return attn_vec, a
        return attn_vec, e
