# Huggingface compatible module
import copy
import math
import random
from turtle import hideturtle
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging


class MemformerEncoderAttention(nn.Module):
    """Multi-headed attention with separate projections for memory slots.

    Two attention passes are computed in one call:

    * memory reading: every encoder position attends over the concatenation
      of all memory slots and all encoder positions;
    * memory writing: every memory slot attends over itself (a single key)
      and all encoder positions.

    Encoder tokens use ``q_proj``/``k_proj``/``v_proj``/``out_proj``; memory
    slots use the ``mem_*`` counterparts.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
    ):
        """Build the projection layers.

        Args:
            embed_dim: Size of the channel dimension of the inputs/outputs.
            num_heads: Number of attention heads; must divide ``embed_dim``.
            dropout: Dropout probability applied to attention probabilities.
            bias: Whether the linear projections carry a bias term.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim ** -0.5

        # Projections for the encoder (token) stream.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Projections for the memory stream.
        self.mem_k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.mem_v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.mem_q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.mem_out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split channels into heads: (bsz, seq_len, embed_dim) -> (bsz, heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_shape``: (bsz, heads, seq_len, head_dim) -> (bsz, seq_len, embed_dim)."""
        bsz, _, seq_len, _ = tensor.size()
        return tensor.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        memory_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run memory reading and memory writing attention.

        Input shape: Batch x Time x Channel.

        Args:
            hidden_states: Encoder inputs, ``(bsz, tgt_len, embed_dim)``.
            memory_states: Memory slots, ``(bsz, mem_len, embed_dim)``.
            attention_mask: Optional additive mask that is summed with the
                attention logits before the softmax. It is sliced along both
                the query and the key axis, so it must cover the full
                ``mem_len + tgt_len`` extent on its last two dimensions with
                memory positions first; leading dimensions must broadcast
                against ``(bsz, num_heads)``. When ``None`` no masking is
                applied.
            layer_head_mask: Not supported; must be ``None``.
            output_attentions: Accepted for interface compatibility; attention
                weights are never returned.

        Returns:
            ``(encoder_output, memory_output, None)`` with shapes
            ``(bsz, tgt_len, embed_dim)`` and ``(bsz, mem_len, embed_dim)``.

        Raises:
            NotImplementedError: If ``layer_head_mask`` is given.
        """
        if layer_head_mask is not None:
            raise NotImplementedError

        bsz, tgt_len, _ = hidden_states.size()
        mem_len = memory_states.size(1)

        # All per-head tensors below are (bsz, num_heads, length, head_dim).
        # NOTE: queries are split into heads exactly once. Head-splitting an
        # already split tensor a second time would reinterpret the buffer and
        # permute heads with positions, breaking the memory/encoder row split.
        encoder_query = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
        encoder_key = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
        encoder_value = self._shape(self.v_proj(hidden_states), tgt_len, bsz)

        memory_query = self._shape(self.mem_q_proj(memory_states) * self.scaling, mem_len, bsz)
        memory_key = self._shape(self.mem_k_proj(memory_states), mem_len, bsz)
        memory_value = self._shape(self.mem_v_proj(memory_states), mem_len, bsz)

        # Keys/values seen by encoder positions: memory slots first, then tokens.
        key_states = torch.cat([memory_key, encoder_key], dim=2)
        value_states = torch.cat([memory_value, encoder_value], dim=2)

        # ---- memory reading: encoder queries over [memory; encoder] keys ----
        encoder_attn_weights = torch.matmul(encoder_query, key_states.transpose(-1, -2))
        if attention_mask is not None:
            # Rows mem_len: of the mask belong to the encoder queries.
            encoder_attn_weights = encoder_attn_weights + attention_mask[:, :, mem_len:, :]
        encoder_attn_weights = torch.softmax(encoder_attn_weights, dim=-1)
        encoder_attn_probs = F.dropout(encoder_attn_weights, p=self.dropout, training=self.training)
        encoder_output = torch.matmul(encoder_attn_probs, value_states)
        encoder_output = self.out_proj(self._merge_heads(encoder_output))

        # ---- memory writing: each slot over [itself; encoder] keys ----
        # A slot only sees its own key among the memory keys, so the
        # "self" logit is a per-slot dot product: (bsz, heads, mem_len, 1).
        mem_self_attn_weights = (memory_query * memory_key).sum(dim=-1, keepdim=True)
        mem_encoder_attn_weights = torch.matmul(memory_query, encoder_key.transpose(-1, -2))

        # attn shape: (batch_size, num_heads, mem_len, 1 + tgt_len)
        memory_attn_weights = torch.cat([mem_self_attn_weights, mem_encoder_attn_weights], dim=-1)
        if attention_mask is not None:
            # Column mem_len - 1 (the last memory key) supplies the mask value
            # for the single "self" slot; the remaining columns are the
            # encoder keys.
            memory_attn_weights = memory_attn_weights + attention_mask[:, :, :mem_len, mem_len - 1:]
        memory_attn_weights = torch.softmax(memory_attn_weights, dim=-1)
        memory_attn_probs = F.dropout(memory_attn_weights, p=self.dropout, training=self.training)

        mem_self_output = memory_attn_probs[..., :1] * memory_value
        mem_encoder_output = torch.matmul(memory_attn_probs[..., 1:], encoder_value)
        memory_output = self._merge_heads(mem_self_output + mem_encoder_output)
        memory_output = self.mem_out_proj(memory_output)

        return encoder_output, memory_output, None