# Huggingface compatible module
import copy
import math
import random
from turtle import hideturtle
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging

# huggingface transformers imports
from transformers import PretrainedConfig
from transformers.models.bart.modeling_bart import (
    BartEncoder,
    BartEncoderLayer,
    BartAttention,
    BartLearnedPositionalEmbedding,
    _expand_mask,
    ACT2FN,
)
from ..modeling_outputs import MemformerEncoderOutput
from .memformer_attention import MemformerAttention
from .memformer_writer import MemoryWriter

# pylint:disable=no-member

# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(__name__)


class BartEncoderLayer(nn.Module):
    """Memory-aware BART encoder layer used by :class:`MemformerEncoder`.

    Mirrors the post-LayerNorm layout of Hugging Face's ``BartEncoderLayer``:
    self-attention, then a two-layer feed-forward network. Each sub-block is
    followed by dropout, a residual connection and a LayerNorm. The difference
    is that self-attention is a :class:`MemformerAttention`, which also
    receives this layer's memory states.

    NOTE: this definition deliberately shadows the ``BartEncoderLayer`` name
    imported from ``transformers.models.bart.modeling_bart`` at the top of the
    module. Later references in this module resolve to this class.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = MemformerAttention(
            embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        memory_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: Optional[torch.Tensor],
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer. As passed by `MemformerEncoder`
                it is batch-first, *(batch, seq_len, embed_dim)*. Here seq_len includes the
                memory-extract positions that the encoder prepends.
            memory_states (`torch.FloatTensor`): this layer's memory, forwarded to
                `MemformerAttention`. Presumably *(batch, memory_len, embed_dim)*, as built by
                `MemformerEncoder.construct_memory`. TODO confirm this for externally supplied memory.
            attention_mask (`torch.FloatTensor`): additive attention mask of size
                *(batch, 1, tgt_len, src_len)*, where padding elements are indicated by very large
                negative values. As built by `MemformerEncoder`, src_len covers the memory positions
                followed by the tgt_len hidden positions.
            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in this layer,
                of size *(encoder_attention_heads,)*, or `None`.
            output_attentions (`bool`, *optional*):
                Whether or not to also return this layer's self-attention weights.

        Returns:
            A pair `(outputs, cache_hidden_states)`.
            - `outputs` is `(hidden_states,)`, or `(hidden_states, attn_weights)` when
              `output_attentions` is true.
            - `cache_hidden_states` is the layer input, unmodified.
        """
        residual = hidden_states
        # Keep the raw layer input so callers can cache it.
        cache_hidden_states = hidden_states

        # Self-attention block (attends over memory + hidden states), post-LayerNorm.
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            memory_states=memory_states,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )

        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward block, post-LayerNorm.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 overflow guard (same as HF BART): if any inf/nan appeared, clamp values
        # back into the finite float16 range.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            # Append the weights as a single tuple element. The caller reads them back as outputs[1].
            outputs += (attn_weights,)

        return outputs, cache_hidden_states


class MemformerEncoder(BartEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is the
    memory-aware :class:`BartEncoderLayer` defined in this module. Each layer is paired with a
    :class:`MemoryWriter` that updates that layer's memory after the layer runs.

    Besides the BART token/position embeddings, the encoder owns:

    * ``memory_extract_tokens``: ``config.memory_extract_len`` learned vectors, prepended to the token
      sequence of every example.
    * ``memory_loc_keys``: learned keys passed to every memory writer.
    * ``memory_forget_bias`` / ``memory_layer_norms``: one bias and one LayerNorm per layer. They are
      applied to that layer's memory before the layer reads it.
    * ``memory_writers``: one :class:`MemoryWriter` per layer.

    Args:
        config: MemformerConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: PretrainedConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        self.embed_positions = BartLearnedPositionalEmbedding(config.max_position_embeddings, embed_dim,)
        # Uses this module's memory-aware BartEncoderLayer (which shadows the transformers class).
        self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False

        # memory states
        self.memory_extract_tokens = nn.Parameter(torch.empty(config.memory_extract_len, config.d_model))
        self.memory_loc_keys = nn.Parameter(torch.empty(1, config.memory_len, embed_dim))
        self.memory_forget_bias = nn.ParameterList([nn.Parameter(torch.empty(1, config.memory_len, config.d_model)) for _ in range(config.encoder_layers)])
        self.memory_layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(config.encoder_layers)])
        self.memory_writers = nn.ModuleList([MemoryWriter(config) for _ in range(config.encoder_layers)])

        # Initialize weights and apply final processing. The memory parameters are created with
        # torch.empty, so they are explicitly re-initialized after post_init.
        self.post_init()
        for idx in range(config.encoder_layers):
            nn.init.normal_(self.memory_forget_bias[idx].data, std=config.init_std)
        nn.init.normal_(self.memory_loc_keys.data, std=config.init_std)
        nn.init.normal_(self.memory_extract_tokens.data, std=config.init_std)

    def forward(
        self,
        input_ids=None,
        memory_states: torch.FloatTensor = None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        r"""
        Encode one segment, reading from and writing to the per-layer memory.

        Args:
            input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states:
                see ``BartEncoder.forward``.
            memory_states: sequence with one memory tensor per encoder layer. It comes from
                :meth:`construct_memory` or from ``memory_states`` returned by a previous call.
                When ``None``, a fresh memory is built with :meth:`construct_memory`. The caller's
                container is copied, never modified.
            return_dict: accepted for API compatibility. A ``MemformerEncoderOutput`` is always returned.

        Returns:
            ``MemformerEncoderOutput`` with these fields:
            - ``last_hidden_state``: includes the ``config.memory_extract_len`` extract positions at the front.
            - ``hidden_states`` / ``attentions``: optional, as in BART.
            - ``memory_states``: the updated per-layer memory.
            - ``encoder_attention_mask``: the boolean token mask with ``True`` prepended for the extract
              positions.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or if
                ``head_mask`` does not have one entry per layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        # Prepend the learned memory-extract tokens to every example (concatenated along dim 1).
        batch_size = hidden_states.shape[0]
        extract_states = self.memory_extract_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        hidden_states = torch.cat([extract_states, hidden_states], dim=1)

        # Construct memory. A caller-supplied container is copied into a fresh list, so the per-layer
        # updates below never mutate the caller's object. This also lets tuples be passed in.
        if memory_states is None:
            memory_states = self.construct_memory(batch_size)
        else:
            memory_states = list(memory_states)

        # Boolean token mask. The default is built from inputs_embeds so it also works when only
        # inputs_embeds was supplied (input_ids is None in that case).
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:-1], dtype=torch.bool, device=inputs_embeds.device)
        else:
            attention_mask = attention_mask.bool()

        # Extract positions are always visible. Memory positions are then prepended on the key side.
        attention_mask = F.pad(attention_mask, (self.config.memory_extract_len, 0), "constant", True)
        memory_len = memory_states[0].shape[1]
        encoder_attention_mask = F.pad(attention_mask, (memory_len, 0), "constant", True)

        # [bsz, src_len] -> [bsz, 1, src_len, src_len]. Then drop the memory rows, because only the
        # extract + token positions act as queries: [bsz, 1, tgt_len, memory_len + tgt_len].
        encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype)
        encoder_attention_mask = encoder_attention_mask[:, :, memory_len:, :]

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            # The memory writer consumes this layer's *input* hidden states.
            pre_hidden_states = hidden_states
            # Apply this layer's forget bias + LayerNorm to its memory before the layer reads it.
            memory_states[idx] = self.memory_layer_norms[idx](memory_states[idx] + self.memory_forget_bias[idx])

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                layer_head_mask = head_mask[idx] if head_mask is not None else None
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    # NOTE(review): the layer returns a nested tuple ((hidden_states, ...), cache_hidden_states).
                    # Reentrant torch.utils.checkpoint may not track gradients through nested outputs.
                    # Confirm that gradients flow when gradient checkpointing is enabled.
                    layer_outputs, _ = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        memory_states[idx],
                        encoder_attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs, _ = encoder_layer(
                        hidden_states,
                        memory_states[idx],
                        encoder_attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

                # Memory Writer: update this layer's memory from the layer's input hidden states.
                memory_states[idx], _ = self.memory_writers[idx](
                    memory_states[idx],
                    hidden_states=pre_hidden_states,
                    memory_loc_keys=self.memory_loc_keys,
                    attention_mask=attention_mask,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return MemformerEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            memory_states=memory_states,
            encoder_attention_mask=attention_mask,
        )

    def construct_memory(self, batch_size):
        """Build the initial memory, one tensor per encoder layer.

        Each layer's initial memory is its LayerNorm-ed forget bias, expanded (as a view, without
        copying) over the batch. The result has shape ``(batch_size, config.memory_len, config.d_model)``.

        Args:
            batch_size: number of examples in the batch.

        Returns:
            A list with one tensor per encoder layer (``config.encoder_layers`` entries).
        """
        return [
            layer_norm(forget_bias).expand(batch_size, -1, -1)
            for layer_norm, forget_bias in zip(self.memory_layer_norms, self.memory_forget_bias)
        ]

