"""Implement the full attention similar to the one implemented by PyTorch's
MultiHeadAttention module. Note that this module is to be used in conjunction
with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
to work."""

from math import sqrt

import torch
from torch.nn import Dropout, Module

from fast_transformers.attention_registry import AttentionRegistry, Optional, Float, \
    EventDispatcherInstance
from fast_transformers.events import EventDispatcher, AttentionEvent


class MagnitudeAttention(Module):
    """Implement the scaled dot product attention with softmax, but weight the
    softmax terms by the magnitude (L2 norm) of the value vectors.

    For a query q the unnormalized weight of key position s is
    ``exp(softmax_temp * <q, k_s>) * ||v_s||``; the weights are then
    normalized over the key positions, with ``eps`` added to the denominator
    to guard against a division by zero.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(MagnitudeAttention, self).__init__()
        # Added to the normalizer so that rows whose weights are all zero do
        # not produce a division by zero
        self.eps = 0.00001
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead magnitude-weighted softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
                           (not used by this implementation)
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of

        Returns
        -------
            A contiguous (N, L, H, D) tensor holding the attended values.
        """
        # Only the key dimensionality is needed for the default temperature
        E = queries.shape[-1]
        softmax_temp = self.softmax_temp or 1. / sqrt(E)

        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        queries = queries * softmax_temp

        # Compute the unnormalized attention (N, H, L, S) and apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
        if not attn_mask.all_ones:
            QK = QK + attn_mask.additive_matrix
        if not key_lengths.all_ones:
            # NOTE(review): presumably additive_matrix is (N, S) here, hence
            # the two broadcast axes for heads and queries -- confirm
            QK = QK + key_lengths.additive_matrix[:, None, None]

        # The norms come out as (N, S, H); rearrange them to (N, H, 1, S) so
        # that they line up with the key axis of QK. Multiplying by the
        # (N, S, H, 1) tensor directly would broadcast against the wrong axes.
        value_norms = values.norm(dim=-1).permute(0, 2, 1).unsqueeze(2)

        # Weight the exponentiated logits by the magnitude of the values
        exp_QK = QK.exp() * value_norms

        # Normalize over the keys and apply the dropout
        A = exp_QK / (exp_QK.sum(dim=-1, keepdim=True) + self.eps)
        A = self.dropout(A)

        # Compute the weighted average of the values
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()


# Register the attention implementation under the name "magnitude" so that it
# becomes available in our builders. Each tuple names a constructor argument
# of MagnitudeAttention followed by its parameter spec.
# NOTE(review): the second argument of Optional(...) is presumably the default
# used when the parameter is not provided -- confirm in attention_registry.
AttentionRegistry.register(
    "magnitude", MagnitudeAttention,
    [
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)

"""Implement the magnitude-weighted softmax attention as a recurrent cross
attention module to speed up autoregressive decoding."""

from math import sqrt

import torch
from torch.nn import Dropout, Module

from fast_transformers.attention_registry import RecurrentCrossAttentionRegistry, Optional, \
    Float, EventDispatcherInstance
from fast_transformers.events import EventDispatcher, AttentionEvent


class RecurrentMagnitudeAttention(Module):
    """Implement autoregressive magnitude-weighted softmax cross attention as
    a recurrent module.

    For the single query q of each step, the unnormalized weight of key
    position s is ``exp(softmax_temp * <q, k_s>) * ||v_s||``; the weights are
    normalized over the key positions, with ``eps`` added to the denominator
    to guard against a division by zero.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(RecurrentMagnitudeAttention, self).__init__()
        # Added to the normalizer so that rows whose weights are all zero do
        # not produce a division by zero
        self.eps = 0.00001
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, query, keys, values, key_lengths, state=None):
        """Attend with a single query per sequence to all the keys.

        Arguments
        ---------
            query: (N, H, E) The tensor containing the query of this step
            keys: (N, S, H, E) The tensor containing the keys (ignored when
                  a state is given)
            values: (N, S, H, D) The tensor containing the values (ignored
                    when a state is given)
            key_lengths: A mask object whose additive_matrix is added to the
                         attention logits to hide the padded keys
            state: The [keys, values] pair returned by a previous call, or
                   None for the first step

        Returns
        -------
            A tuple of the contiguous (N, H, D) attended values and the
            [keys, values] state to pass to the next call.
        """
        # Only the key dimensionality is needed for the default temperature
        E = query.shape[-1]
        softmax_temp = self.softmax_temp or 1. / sqrt(E)

        # Extract the keys and values either from the arguments or the state
        if state is not None:
            keys, values = state

        # Scale the query instead of applying the softmax temperature to the
        # dot products. Without this the temperature computed above would
        # never influence the attention.
        query = query * softmax_temp

        # Compute the unnormalized attention (N, S, H) and apply the key
        # length mask
        QK = torch.einsum("nhe,nshe->nsh", query, keys)
        QK = QK + key_lengths.additive_matrix[:, :, None]

        # Weight the exponentiated logits by the magnitude of the values; the
        # norms are (N, S, H) and therefore already aligned with QK
        exp_QK = QK.exp() * values.norm(dim=-1)

        # Normalize over the keys (dim 1) and apply the dropout
        A = exp_QK / (exp_QK.sum(dim=1, keepdim=True) + self.eps)
        A = self.dropout(A)

        # Compute the weighted average of the values
        V = torch.einsum("nsh,nshd->nhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that we return a contiguous value
        return V.contiguous(), [keys, values]


# Register the recurrent cross attention implementation under the name
# "magnitude" so that it becomes available in our builders. Each tuple names
# a constructor argument of RecurrentMagnitudeAttention followed by its
# parameter spec.
# NOTE(review): the second argument of Optional(...) is presumably the default
# used when the parameter is not provided -- confirm in attention_registry.
RecurrentCrossAttentionRegistry.register(
    "magnitude", RecurrentMagnitudeAttention,
    [
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
