from typing import Dict, List
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

# pylint:disable=no-member

# @dataclass
# class MemoryCrossAttnOutput:
#     out_hidden: torch.FloatTensor = None
#     cross_key_value: torch.FloatTensor = None
#     attn_weights: torch.FloatTensor = None


class MemoryCrossAttention(nn.Module):
    """Multi-head cross-attention from a query sequence onto memory states.

    Queries are projected from ``hidden_states``; keys and values are projected
    jointly from ``memory_states``. An optional additive term (``memory_keys``)
    can be added to the projected keys before the heads are split.

    The per-head width is ``head_dim`` and the total attention width is
    ``inner_dim = num_heads * head_dim``, which does not have to equal
    ``embed_dim``; ``out_proj`` maps ``inner_dim`` back to ``embed_dim``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        head_dim: int,
        memory_len: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
    ):
        """
        Args:
            num_heads: number of attention heads.
            embed_dim: feature size of the inputs and of the output.
            head_dim: feature size of each attention head.
            memory_len: stored on the module; not used by the computation here.
            dropout: dropout probability applied to the projected output.
            dropattn: dropout probability applied to the attention weights.
        """
        super().__init__()

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        self.memory_len = memory_len
        self.dropout = dropout
        self.dropattn = dropattn

        self.q_proj = nn.Linear(embed_dim, self.inner_dim)
        # Keys and values share a single projection and are split afterwards.
        self.kv_proj = nn.Linear(embed_dim, 2 * self.inner_dim)

        self.out_proj = nn.Linear(self.inner_dim, embed_dim)

        # Scale logits by 1/sqrt(head_dim) to keep their variance around 1.
        self.scale = 1.0 / (head_dim ** 0.5)

    def _compute_k_v(self, hidden: torch.Tensor, memory_keys, batch_size: int):
        """Project memory states into per-head keys and values.

        Args:
            hidden: memory states, shape (batch_size, key_len, embed_dim).
            memory_keys: optional tensor added to the projected keys; it must
                broadcast against (batch_size, key_len, inner_dim). ``None``
                skips the addition.
            batch_size: size of the leading dimension of ``hidden``.

        Returns:
            key: shape (batch_size, num_heads, head_dim, key_len), i.e. already
                transposed for ``query @ key``.
            value: shape (batch_size, num_heads, key_len, head_dim).
        """
        # shape: (batch_size, key_len, 2 * inner_dim)
        projected = self.kv_proj(hidden)
        # key / value shape: (batch_size, key_len, inner_dim)
        key, value = projected.split(self.inner_dim, dim=2)
        # `forward` defaults memory_keys to None, so the additive term is optional.
        if memory_keys is not None:
            key = key + memory_keys

        # key shape: (batch_size, num_heads, head_dim, key_len)
        key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
        # value shape: (batch_size, num_heads, key_len, head_dim)
        value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return key, value

    def forward(
        self, hidden_states: torch.FloatTensor, memory_states: torch.FloatTensor, memory_keys: torch.FloatTensor = None
    ):
        """Attend from ``hidden_states`` onto ``memory_states``.

        Args:
            hidden_states: shape (batch, query_len, embed_dim)
            memory_states: shape (batch, key_len, embed_dim)
            memory_keys: optional tensor added to the projected keys; must
                broadcast against (batch, key_len, inner_dim).

        Returns:
            A tuple ``(hidden_states, attn_weights)`` where ``hidden_states``
            has shape (batch, query_len, embed_dim) and ``attn_weights`` are the
            softmax weights *before* attention dropout, with shape
            (batch, num_heads, query_len, key_len).
        """
        batch_size = hidden_states.size(0)
        query_len = hidden_states.size(1)

        # query shape: (batch, num_heads, query_len, head_dim)
        query = self.q_proj(hidden_states).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)

        # key shape: (batch, num_heads, head_dim, key_len)
        # value shape: (batch, num_heads, key_len, head_dim)
        key, value = self._compute_k_v(memory_states, memory_keys, batch_size)

        # shape: (batch, num_heads, query_len, key_len)
        attn_logits = torch.matmul(query, key) * self.scale

        # shape: (batch, num_heads, query_len, key_len)
        attn_weights = torch.softmax(attn_logits, dim=-1)
        attn_probs = F.dropout(attn_weights, p=self.dropattn, training=self.training)

        # shape: (batch, num_heads, query_len, head_dim)
        hidden_states = torch.matmul(attn_probs, value)
        # Merge heads: the merged width is inner_dim (num_heads * head_dim),
        # which is what out_proj expects and may differ from embed_dim.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, query_len, self.inner_dim)

        # Output projection back to embed_dim
        hidden_states = self.out_proj(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states, attn_weights
