# Huggingface compatible module
import copy
import math
import random
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging

# huggingface transformers imports
from transformers.models.bart.modeling_bart import (
    BartEncoder,
    BartEncoderLayer,
    BartAttention,
    _expand_mask,
    ACT2FN,
)


class MemoryWriter(nn.Module):
    """Write encoder hidden states into a fixed set of memory slots.

    Each memory slot owns a learned query. That query attends over the
    encoder sequence and over *its own* current slot (a single extra "self"
    key), so each slot's new state is a convex combination of its previous
    state and the value-projected encoder states.
    """

    def __init__(self, config):
        """Build the projections and the learned per-slot queries.

        Reads ``memory_writer_num_heads``, ``d_model``, ``dropout``,
        ``memory_len`` and ``memory_writer_attn_tau`` from ``config``.

        Raises:
            ValueError: if ``d_model`` is not divisible by the head count.
                The per-head reshape in ``forward`` would otherwise fail
                with an opaque shape error.
        """
        super().__init__()
        self.num_heads = config.memory_writer_num_heads
        self.embed_dim = config.d_model
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.embed_dim}) must be divisible by "
                f"memory_writer_num_heads ({self.num_heads})"
            )
        self.head_dim = self.embed_dim // self.num_heads
        self.inner_dim = self.num_heads * self.head_dim
        # Stored for reference only; forward() does not apply dropout.
        self.dropout = config.dropout

        self.mem_k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.encoder_k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.encoder_v_proj = nn.Linear(self.embed_dim, self.inner_dim)

        # One learned query per memory slot, shared across the batch.
        self.memory_queries = nn.Parameter(torch.empty(1, config.memory_len, self.embed_dim))

        # Logits are multiplied by 1/sqrt(head_dim) (keeps their variance
        # around 1) and divided by the temperature attn_tau.
        self.attn_tau = config.memory_writer_attn_tau
        self.scale = 1.0 / (self.head_dim ** 0.5)

        nn.init.normal_(self.memory_queries.data, std=0.02)

    def _split_heads(self, tensor, batch_size):
        """Reshape (batch, seq, num_heads * head_dim) -> (batch, num_heads, seq, head_dim)."""
        return tensor.view(batch_size, tensor.size(1), self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        memory_states,
        hidden_states,
        memory_loc_keys: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``hidden_states`` into ``memory_states``.

        Args:
            memory_states: current memory, ``(batch, memory_len, d_model)``.
            hidden_states: encoder states, ``(batch, encoder_seq_len, d_model)``.
            memory_loc_keys: optional term added to the content keys of the
                memory slots; must broadcast to ``(batch, memory_len, d_model)``.
                ``None`` means content keys only.
            attention_mask: optional ``(batch, encoder_seq_len)`` padding mask.
                ``True`` / non-zero marks encoder positions that may be
                attended to. Additive (0 / -inf) masks are NOT supported.
                ``None`` means every encoder position is visible.

        Returns:
            ``(out_memory_states, attn_weights)`` where ``out_memory_states``
            is ``(batch, memory_len, d_model)`` and ``attn_weights`` is
            ``(batch, num_heads, memory_len, 1 + encoder_seq_len)``; index 0
            of the last axis is the weight each slot puts on itself.

        Raises:
            ValueError: if ``memory_states`` does not have as many slots as
                there are learned memory queries.
        """
        batch_size = memory_states.size(0)
        memory_len = memory_states.size(1)
        if memory_len != self.memory_queries.size(1):
            raise ValueError(
                f"memory_states has {memory_len} slots but the writer was built "
                f"with memory_len={self.memory_queries.size(1)}"
            )

        # Queries: (batch, num_heads, memory_len, head_dim)
        queries = self._split_heads(self.memory_queries.expand(batch_size, -1, -1), batch_size)

        # Each slot has exactly one "self" key: (batch, num_heads, memory_len, head_dim, 1)
        memory_keys = self.mem_k_proj(memory_states)
        if memory_loc_keys is not None:
            memory_keys = memory_keys + memory_loc_keys
        memory_keys = self._split_heads(memory_keys, batch_size).unsqueeze(-1)

        # Encoder keys: (batch, num_heads, head_dim, encoder_seq_len)
        encoder_keys = self._split_heads(self.encoder_k_proj(hidden_states), batch_size).transpose(-1, -2)

        # (batch, num_heads, memory_len, encoder_seq_len)
        encoder_attn_logits = torch.matmul(queries, encoder_keys)
        # Per-slot dot product of a query with its own key: (batch, num_heads, memory_len, 1)
        self_attn_logits = torch.matmul(queries.unsqueeze(-2), memory_keys).squeeze(-1)

        if attention_mask is not None:
            keep = attention_mask if attention_mask.dtype == torch.bool else attention_mask != 0
            # The (batch, 1, 1, seq) mask broadcasts over heads and slots.
            encoder_attn_logits = encoder_attn_logits.masked_fill(
                ~keep[:, None, None, :], torch.finfo(encoder_attn_logits.dtype).min
            )

        # The self logit is never masked, so every softmax row has at least
        # one finite entry even when all encoder positions are padding.
        attn_logits = torch.cat([self_attn_logits, encoder_attn_logits], dim=-1) * self.scale / self.attn_tau
        attn_weights = torch.softmax(attn_logits, dim=-1)

        # Values: encoder states are projected, memory states are used as-is.
        encoder_value = self._split_heads(self.encoder_v_proj(hidden_states), batch_size)
        self_value = self._split_heads(memory_states, batch_size)

        # Slice positionally (self slot first, encoder after) so that an empty
        # encoder sequence yields an empty slice; a negative-length slice
        # ([-0:]) would select the whole tensor instead.
        self_attn_probs = attn_weights[..., :1]
        encoder_attn_probs = attn_weights[..., 1:]

        encoder_attn_out = torch.matmul(encoder_attn_probs, encoder_value)
        self_attn_out = self_attn_probs * self_value

        out_memory_states = encoder_attn_out + self_attn_out
        out_memory_states = out_memory_states.transpose(1, 2).reshape(batch_size, memory_len, self.embed_dim)

        return out_memory_states, attn_weights