# Huggingface compatible module
import copy
import math
import random
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging

# huggingface transformers imports
from transformers.models.bart.modeling_bart import (
    BartEncoder,
    BartEncoderLayer,
    BartAttention,
    _expand_mask,
    ACT2FN,
)


class ZeroGate(nn.Module):
    """Zero-initialised linear gate producing one weight per memory slot.

    Computes ``GELU(hidden_states @ zero_weight.T + zero_bias)`` where
    ``zero_weight`` has shape ``(memory_len, input_dim)`` and ``zero_bias`` has
    shape ``(memory_len,)``. Both parameters start at zero, so the gate outputs
    exactly 0 at initialisation (GELU(0) == 0).
    """

    def __init__(
        self, input_dim: int, memory_len: int,
    ):
        super().__init__()
        # Allocate directly in F.linear's (out_features, in_features) layout so
        # the parameter is contiguous (a transposed view would have the same
        # shape but non-contiguous strides).
        self.zero_weight = nn.Parameter(torch.zeros(memory_len, input_dim))
        self.zero_bias = nn.Parameter(torch.zeros(memory_len))
        self.act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map ``(..., input_dim)`` inputs to ``(..., memory_len)`` gate values."""
        return self.act_fn(F.linear(hidden_states, self.zero_weight, self.zero_bias))


class MemformerWriterExtraLayer(nn.Module):
    """Post-LayerNorm transformer layer applied to encoder states before writing.

    Self-attention followed by a feed-forward block, each with dropout, a
    residual connection and a LayerNorm after the residual add. Unlike a
    standard encoder layer, the feed-forward block keeps width ``d_model``
    (``fc1`` and ``fc2`` are both ``d_model -> d_model``).
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BartAttention(
            embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        # a smaller projection: the FFN stays at d_model width
        self.fc1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.fc2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): batch-first input of shape
                *(batch, seq_len, embed_dim)*.
            attention_mask (`torch.Tensor` or `None`): 2D padding mask of shape
                *(batch, seq_len)*, 1/True for positions to attend to and 0/False
                for padding. It is expanded internally with `_expand_mask`.
                `None` disables masking.
            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention
                heads in this layer, of size *(encoder_attention_heads,)*.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights.

        Returns:
            `(outputs, cache_hidden_states)`. `outputs` is `(hidden_states,)`, or
            `(hidden_states, attn_weights)` when `output_attentions` is True.
            `cache_hidden_states` is the unmodified input `hidden_states`.
        """
        residual = hidden_states
        cache_hidden_states = hidden_states
        expanded_attention_mask = (
            _expand_mask(attention_mask, hidden_states.dtype) if attention_mask is not None else None
        )

        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=expanded_attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )

        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 overflow guard: clamp back into the representable range
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            # wrap in a tuple: `tuple += tensor` is not tuple concatenation
            outputs += (attn_weights,)

        return outputs, cache_hidden_states


class MemoryWriter(nn.Module):
    """Write encoder hidden states into a fixed number of memory slots.

    Each slot issues one query per head and attends over (a) its own slot key
    only (a single "self" logit per slot) and (b) every encoder position. The
    new slot value mixes the slot's previous raw state (weighted by the self
    probability) with the projected encoder values, and is then scaled by a
    per-slot gate computed from encoder position 0.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.memory_writer_num_heads
        self.embed_dim = config.d_model
        if self.embed_dim % self.num_heads != 0:
            # forward() splits raw memory states into heads, which requires an
            # even split; fail here rather than on the first forward pass.
            raise ValueError(
                f"d_model ({self.embed_dim}) must be divisible by "
                f"memory_writer_num_heads ({self.num_heads})"
            )
        self.head_dim = self.embed_dim // self.num_heads
        self.inner_dim = self.num_heads * self.head_dim
        # not applied anywhere in forward()
        self.dropout = config.dropout

        self.block = MemformerWriterExtraLayer(config)

        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.weight_net = ZeroGate(self.embed_dim, memory_len=config.memory_len)

        # logits are multiplied by scale / attn_tau; 1/sqrt(head_dim) keeps the
        # variance around 1
        self.attn_tau = config.memory_writer_attn_tau
        self.scale = 1.0 / (self.head_dim ** 0.5)

    def forward(
        self,
        memory_states,
        hidden_states,
        memory_loc_keys: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Write ``hidden_states`` into ``memory_states``.

        Args:
            memory_states: current memory, ``(batch, memory_len, d_model)``.
            hidden_states: encoder outputs, ``(batch, src_len, d_model)``;
                position 0 drives the write gate.
            memory_loc_keys: optional location keys of shape
                ``(1 or batch, memory_len, d_model)``, added to the slot keys.
            attention_mask: optional ``(batch, src_len)`` mask, True/1 for
                positions to attend to. ``None`` attends to every position.

        Returns:
            ``(out_memory_states, self_attn_probs)``. ``out_memory_states`` has
            shape ``(batch, memory_len, d_model)``. ``self_attn_probs`` has shape
            ``(batch, num_heads, memory_len, 1)`` and is the weight each slot
            puts on its previous state.
        """
        batch_size, memory_len = memory_states.shape[:2]
        encoder_key_len = hidden_states.size(1)

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, encoder_key_len, dtype=torch.bool, device=hidden_states.device
            )
        # masked_fill needs a boolean mask; also accept 0/1 integer masks
        keep_mask = attention_mask.bool()

        # additional transformer layer over the encoder states
        layer_outputs, _ = self.block(hidden_states, attention_mask)
        hidden_states = layer_outputs[0]
        weight_gate = self.weight_net(hidden_states[:, 0])  # (batch, memory_len)

        # query: (batch, heads, memory_len, head_dim)
        query = (
            self.q_proj(memory_states)
            .view(batch_size, memory_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # keys for slots and encoder positions from one projection:
        # (batch, heads, head_dim, memory_len + src_len)
        combined_states = torch.cat([memory_states, hidden_states], dim=1)
        key = (
            self.k_proj(combined_states)
            .view(batch_size, combined_states.shape[1], self.num_heads, self.head_dim)
            .permute(0, 2, 3, 1)
        )

        # encoder logits: (batch, heads, memory_len, src_len)
        encoder_attn_logits = torch.matmul(query, key[..., memory_len:])

        # self logits: each slot scores only its own key -> (batch, heads, memory_len, 1)
        memory_key = key[..., :memory_len]
        if memory_loc_keys is not None:
            # location keys shared across the batch (leading dim 1) or per sample
            loc_key = memory_loc_keys.reshape(
                -1, memory_loc_keys.shape[1], self.num_heads, self.head_dim
            ).permute(0, 2, 3, 1)
            memory_key = loc_key + memory_key
        memory_key = memory_key.transpose(-1, -2)[..., None]
        self_attn_logits = torch.matmul(query[..., None, :], memory_key).squeeze(-1)

        # Scale first, then mask padding, so the fill value stays exactly
        # finfo.min instead of being rescaled (possibly to -inf for small tau).
        self_attn_logits = self_attn_logits * self.scale / self.attn_tau
        encoder_attn_logits = encoder_attn_logits * self.scale / self.attn_tau
        encoder_attn_logits = encoder_attn_logits.masked_fill(
            ~keep_mask[:, None, None, :], torch.finfo(encoder_attn_logits.dtype).min
        )
        # attn shape: (batch, heads, memory_len, 1 + src_len)
        attn_weights = torch.softmax(torch.cat([self_attn_logits, encoder_attn_logits], dim=-1), dim=-1)

        # values: projected encoder states; the self term uses raw memory states
        encoder_value = (
            self.v_proj(hidden_states)
            .view(batch_size, encoder_key_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        self_value = memory_states.reshape(batch_size, memory_len, self.num_heads, self.head_dim).transpose(1, 2)

        self_attn_probs = attn_weights[..., :1]
        encoder_attn_probs = attn_weights[..., 1:]

        out_memory_states = torch.matmul(encoder_attn_probs, encoder_value) + self_attn_probs * self_value
        out_memory_states = out_memory_states.transpose(1, 2).reshape(batch_size, memory_len, self.embed_dim)

        # scale each slot by its write gate
        out_memory_states = out_memory_states * weight_gate[..., None]

        return out_memory_states, self_attn_probs
