# Huggingface compatible module
import copy
import math
import random
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging

# huggingface transformers imports
from transformers.models.bart.modeling_bart import (
    BartEncoder,
    BartEncoderLayer,
    BartAttention,
    _expand_mask,
    ACT2FN,
)


class ZeroHighwayGate(nn.Module):
    """Predict one gate value per memory slot from a single hidden vector.

    Pipeline: dropout -> dense -> tanh -> dropout -> linear projection to
    ``memory_len`` outputs. The final projection's weight and bias start at
    zero, so the module outputs exactly zero for any input until trained.
    No squashing non-linearity (e.g. sigmoid) is applied to the output.
    """

    def __init__(
        self, input_dim: int, inner_dim: int, memory_len: int, pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        # Laid out as (out_features, in_features) as F.linear expects. Allocating
        # this shape directly keeps the parameter contiguous; wrapping a
        # transposed view would produce a non-contiguous parameter.
        self.zero_weight = nn.Parameter(torch.zeros(memory_len, inner_dim))
        self.zero_bias = nn.Parameter(torch.zeros(memory_len))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map ``(..., input_dim)`` inputs to ``(..., memory_len)`` gate values."""
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        gate = F.linear(hidden_states, self.zero_weight, self.zero_bias)
        return gate


class MemoryWriter(nn.Module):
    """Attention-based writer that folds encoder hidden states into a memory.

    Every memory slot issues one query per head and attends jointly (single
    softmax) over two sources:

    * itself -- scored against a key built from the slot's own projected state
      plus an optional caller-supplied location key; its value is the slot's
      raw (un-projected) state;
    * every encoder position -- scored against ``k_proj(hidden_states)``; the
      values are ``v_proj(hidden_states)``.

    The attended result is then multiplied, slot-wise, by ``forget_net``
    applied to the first encoder position.

    Config attributes read: ``memory_extract_len``, ``memory_writer_num_heads``,
    ``d_model``, ``dropout``, ``memory_len`` and ``memory_writer_attn_tau``.
    """

    def __init__(self, config):
        super().__init__()
        self.memory_extract_len = config.memory_extract_len  # kept as an attribute; forward() does not read it
        self.num_heads = config.memory_writer_num_heads
        self.embed_dim = config.d_model
        self.head_dim = self.embed_dim // self.num_heads
        self.inner_dim = self.num_heads * self.head_dim
        self.dropout = config.dropout

        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
        # One gate value per memory slot, predicted from a single hidden vector.
        self.forget_net = ZeroHighwayGate(
            self.embed_dim, self.embed_dim, memory_len=config.memory_len, pooler_dropout=self.dropout
        )

        # 1/sqrt(head_dim) keeps dot-product logits at roughly unit variance;
        # attn_tau is an additional softmax temperature.
        self.attn_tau = config.memory_writer_attn_tau
        self.scale = 1.0 / (self.head_dim ** 0.5)

    def forward(
        self,
        memory_states: torch.Tensor,
        hidden_states: torch.Tensor,
        memory_loc_keys: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``hidden_states`` into ``memory_states``.

        Args:
            memory_states: ``(batch, memory_len, d_model)`` current memory.
                ``d_model`` is assumed divisible by ``num_heads``.
            hidden_states: ``(batch, seq_len, d_model)`` encoder states.
                Position 0 is the input to ``forget_net``.
            memory_loc_keys: per-slot location keys holding ``memory_len * d_model``
                elements per entry, e.g. ``(1, memory_len, d_model)`` (shared across
                the batch), ``(memory_len, d_model)`` or ``(batch, memory_len, d_model)``.
                They are added to the slots' own keys. ``None`` adds nothing.
            attention_mask: ``(batch, seq_len)`` keep-mask over ``hidden_states``:
                ``True``/nonzero positions may be attended, ``False``/0 positions are
                excluded. ``None`` attends to every position.

        Returns:
            ``(new_memory_states, self_attn_probs)``: the written memory of shape
            ``(batch, memory_len, d_model)`` and each slot's attention weight on
            itself, shape ``(batch, num_heads, memory_len, 1)``.
        """
        batch_size = memory_states.size(0)
        memory_len = memory_states.size(1)
        encoder_key_len = hidden_states.size(1)

        # (batch, memory_len): one multiplier per slot, from the first encoder position.
        forget_gate = self.forget_net(hidden_states[:, 0])

        # query: (batch, heads, memory_len, head_dim)
        query = (
            self.q_proj(memory_states)
            .view(batch_size, memory_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Keys for memory slots and encoder positions from one projection:
        # (batch, heads, head_dim, memory_len + seq_len)
        combined_states = torch.cat([memory_states, hidden_states], dim=1)
        key = (
            self.k_proj(combined_states)
            .view(batch_size, combined_states.size(1), self.num_heads, self.head_dim)
            .permute(0, 2, 3, 1)
        )
        memory_key = key[..., :memory_len]
        # Slice by memory_len rather than -seq_len, which would select all keys if seq_len were 0.
        encoder_key = key[..., memory_len:]

        # Slot -> encoder logits, scaled before masking so masked entries stay at
        # the dtype's finite minimum instead of overflowing to -inf:
        # (batch, heads, memory_len, seq_len)
        encoder_attn_logits = torch.matmul(query, encoder_key) * self.scale / self.attn_tau

        # Slot -> itself logits: each slot scores only its own key.
        if memory_loc_keys is not None:
            # (-1, heads, head_dim, memory_len); a leading dim of 1 broadcasts over batch.
            loc_keys = memory_loc_keys.reshape(-1, memory_len, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
            memory_key = loc_keys + memory_key
        self_key = memory_key.transpose(-1, -2)[..., None]  # (batch, heads, memory_len, head_dim, 1)
        # (batch, heads, memory_len, 1)
        self_attn_logits = torch.matmul(query[..., None, :], self_key).squeeze(-1) * self.scale / self.attn_tau

        if attention_mask is not None:
            # Accept bool or 0/1 integer/float masks. Excluded positions get the most
            # negative finite value, so softmax gives them (effectively) zero weight.
            # The self logit is always unmasked, so a fully-masked row stays well-defined.
            excluded = ~attention_mask.bool()[:, None, None, :]
            encoder_attn_logits = encoder_attn_logits.masked_fill(
                excluded, torch.finfo(encoder_attn_logits.dtype).min
            )

        # Joint softmax over [self, encoder positions]: (batch, heads, memory_len, 1 + seq_len)
        attn_weights = torch.softmax(torch.cat([self_attn_logits, encoder_attn_logits], dim=-1), dim=-1)
        self_attn_probs = attn_weights[..., :1]
        encoder_attn_probs = attn_weights[..., 1:]

        # Values: projected encoder states, and the slots' raw (un-projected) states.
        encoder_value = (
            self.v_proj(hidden_states).view(batch_size, encoder_key_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        self_value = memory_states.reshape(batch_size, memory_len, self.num_heads, self.head_dim).transpose(1, 2)

        encoder_attn_out = torch.matmul(encoder_attn_probs, encoder_value)
        self_attn_out = self_attn_probs * self_value

        out_memory_states = encoder_attn_out + self_attn_out
        out_memory_states = out_memory_states.transpose(1, 2).reshape(batch_size, memory_len, self.embed_dim)

        # Multiply each slot by its gate value (used as-is; no squashing is applied here).
        out_memory_states = out_memory_states * forget_gate[..., None]

        return out_memory_states, self_attn_probs
