# Huggingface compatible module
import copy
import math
import random
import warnings
from typing import Any, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import logging

# huggingface transformers imports
from transformers.models.bart.modeling_bart import (
    BartEncoder,
    BartEncoderLayer,
    BartAttention,
    _expand_mask,
    ACT2FN,
)
from ..modeling_outputs import MemformerEncoderOutput

# pylint:disable=no-member

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class MemformerEncoder(BartEncoder):
    """
    :class:`BartEncoder` subclass that carries a block of memory slots.

    On each call the memory is passed through a :class:`MemoryReader`, concatenated in
    front of the embedded input along the sequence axis, and run through ``self.layers``.
    A :class:`MemoryWriter` then produces the next memory from the final hidden states.

    Args:
        config: MemformerConfig
        embed_tokens (nn.Embedding): embedding module forwarded to ``BartEncoder.__init__``
    """

    def __init__(self, config, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config, embed_tokens)
        self.memory_len = config.memory_len
        self.mem_reader = MemoryReader(config)
        self.mem_writer = MemoryWriter(config)
        # Called again here, after the memory sub-modules have been created.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        memory_states: torch.FloatTensor = None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        r"""
        Encode the input together with the memory slots and compute the next memory.

        Args:
            input_ids: token ids; reshaped to ``(-1, last_dim)`` before embedding.
                Mutually exclusive with ``inputs_embeds``.
            memory_states: memory from a previous call. When ``None`` a fresh memory is
                built with :meth:`construct_memory`.
            attention_mask: optional mask over the input tokens. It is left-padded with
                ``memory_len`` ones on its last axis before being expanded.
            head_mask: optional per-layer head mask; its first dimension must equal
                ``len(self.layers)``.
            inputs_embeds: pre-computed embeddings used instead of ``input_ids``.
            output_attentions: falls back to ``config.output_attentions`` when ``None``.
            output_hidden_states: falls back to ``config.output_hidden_states`` when ``None``.
            return_dict: accepted for signature compatibility; not read by this method.

        Returns:
            MemformerEncoderOutput whose ``last_hidden_state`` holds the memory positions
            followed by the token positions, plus the new ``memory_states`` and the
            padded (unexpanded) ``encoder_attention_mask``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given,
                or if ``head_mask`` does not have one entry per layer.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Exactly one of input_ids / inputs_embeds has to be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        else:
            input_shape = inputs_embeds.size()[:-1]

        # Token embeddings + positions -> layer norm -> dropout.
        token_states = self.layernorm_embedding(
            inputs_embeds + self.embed_positions(input_shape)
        )
        token_states = F.dropout(token_states, p=self.dropout, training=self.training)

        # Start from the initial memory when the caller did not provide one.
        if memory_states is None:
            memory_states = self.construct_memory(token_states.shape[0])

        # Read the memory and put it in front of the tokens on the sequence axis.
        memory_states = self.mem_reader(memory_states)
        hidden_states = torch.concat([memory_states, token_states], dim=1)

        encoder_attention_mask = None
        if attention_mask is not None:
            # Extend the mask on the left with ones, one per memory slot.
            encoder_attention_mask = F.pad(
                attention_mask, (self.memory_len, 0), "constant", value=1
            )
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype)

        # head_mask, when given, needs one entry per encoder layer.
        if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
            raise ValueError(
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
            )

        def checkpointable(module):
            # Bind the layer now: checkpointing may call the wrapper again later.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)

            return custom_forward

        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                collected_states.append(hidden_states)

            # LayerDrop (see https://arxiv.org/abs/1909.11556 for description).
            # The draw is made for every layer, whether or not the model is training.
            drop_draw = random.uniform(0, 1)
            skip_layer = self.training and (drop_draw < self.layerdrop)

            if skip_layer:
                layer_outputs = (None, None)
            else:
                layer_head_mask = None if head_mask is None else head_mask[layer_idx]
                if self.gradient_checkpointing and self.training:
                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        checkpointable(layer),
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]

            if output_attentions:
                # A skipped layer contributes ``None`` here.
                collected_attentions.append(layer_outputs[1])

        if output_hidden_states:
            collected_states.append(hidden_states)

        # Derive the next memory from the final hidden states (memory + tokens).
        new_memory = self.mem_writer(hidden_states, attention_mask)

        return MemformerEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=(
                tuple(collected_states) if collected_states is not None else None
            ),
            attentions=(
                tuple(collected_attentions)
                if collected_attentions is not None
                else None
            ),
            memory_states=new_memory,
            encoder_attention_mask=encoder_attention_mask,
        )

    def construct_memory(self, batch_size):
        """Return the initial memory: the writer's layer-normed ``memory_bias``, expanded to ``batch_size``."""
        writer = self.mem_writer
        initial_memory = writer.final_layer_norm(writer.memory_bias)
        return initial_memory.expand(batch_size, -1, -1)


class MemoryReader(nn.Module):
    """Stack of ``config.memory_reader_layers`` :class:`BartEncoderLayer` modules applied to the memory."""

    def __init__(self, config) -> None:
        super().__init__()
        self.dropout = config.dropout
        self.memory_len = config.memory_len
        reader_stack = [
            BartEncoderLayer(config) for _ in range(config.memory_reader_layers)
        ]
        self.layers = nn.ModuleList(reader_stack)

    def forward(self, memory_states):
        """
        Run ``memory_states`` through every reader layer in order.

        No attention mask or head mask is used. With zero layers the input is
        returned unchanged.
        """
        current = memory_states
        for reader_layer in self.layers:
            # Each layer returns a tuple; its first element is the updated states.
            current = reader_layer(
                current,
                attention_mask=None,
                layer_head_mask=None,
                output_attentions=False,
            )[0]
        return current


class MemoryFusionLayer(nn.Module):
    """
    Cross-attention block in which the memory slots attend to the encoder states.

    The memory is used as the attention query and ``hidden_states`` as keys/values.
    The attention output goes through a layer norm and a two-layer feed-forward
    network with a residual connection. No normalisation is applied after the
    feed-forward residual inside this layer.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.embed_dim = config.d_model
        self.memory_len = config.memory_len

        self.encoder_attn = BartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

    def forward(
        self,
        memory_states: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Update the memory slots from the encoder states.

        Args:
            memory_states: current memory, used as the attention query.
            hidden_states: encoder states, used as attention keys and values.
            attention_mask: optional expanded mask; only its first ``memory_len``
                rows along the third axis are kept so it lines up with the memory
                queries (assumes a 4-D ``[bsz, 1, tgt_len, src_len]`` layout —
                TODO confirm against the caller).
            output_attentions: when truthy, the cross-attention weights are appended
                to the returned tuple.

        Returns:
            ``(memory_states,)``, or ``(memory_states, cross_attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :self.memory_len, :]
        memory_states, cross_attn_weights, _ = self.encoder_attn(
            hidden_states=memory_states,
            key_value_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=None,
            output_attentions=output_attentions,
        )
        # NOTE(review): the residual is captured *after* attention, so the incoming
        # memory is not on the skip path (the sum is attn_out + dropout(attn_out)).
        # Kept as-is because changing it would alter the outputs of already-trained
        # weights — confirm whether this is intended.
        residual = memory_states
        memory_states = nn.functional.dropout(
            memory_states, p=self.dropout, training=self.training
        )
        memory_states = residual + memory_states
        memory_states = self.encoder_attn_layer_norm(memory_states)

        # Position-wise feed-forward network with a residual connection.
        residual = memory_states
        memory_states = self.activation_fn(self.fc1(memory_states))
        memory_states = nn.functional.dropout(
            memory_states, p=self.activation_dropout, training=self.training
        )
        memory_states = self.fc2(memory_states)
        memory_states = nn.functional.dropout(
            memory_states, p=self.dropout, training=self.training
        )
        memory_states = residual + memory_states

        outputs = (memory_states,)
        if output_attentions:
            # Previously the flag was forwarded to the attention module but the
            # resulting weights were dropped; expose them to the caller.
            outputs += (cross_attn_weights,)
        return outputs


class MemoryWriter(nn.Module):
    """
    Produces the next memory from the encoder's final hidden states.

    Holds ``config.memory_writer_self_layers`` :class:`BartEncoderLayer` modules,
    ``config.memory_writer_fusion_layers`` :class:`MemoryFusionLayer` modules, a layer
    norm, and the learnable ``memory_bias`` of shape ``(1, memory_len, d_model)``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.dropout = config.dropout
        self.memory_len = config.memory_len

        self_stack = [
            BartEncoderLayer(config) for _ in range(config.memory_writer_self_layers)
        ]
        self.self_layers = nn.ModuleList(self_stack)

        fusion_stack = [
            MemoryFusionLayer(config)
            for _ in range(config.memory_writer_fusion_layers)
        ]
        self.fusion_layers = nn.ModuleList(fusion_stack)

        self.final_layer_norm = nn.LayerNorm(config.d_model)

        # Learnable memory parameter, re-drawn from a normal with std ``config.init_std``.
        initial_memory = torch.randn(1, config.memory_len, config.d_model)
        self.memory_bias = nn.Parameter(initial_memory)
        nn.init.normal_(self.memory_bias.data, std=config.init_std)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Compute the new memory states.

        Args:
            hidden_states: encoder output whose first ``memory_len`` positions along
                axis 1 are the memory slots.
            attention_mask: optional mask passed unchanged to every self and fusion layer.
            output_attentions: accepted but not read; no attention weights are returned.

        Returns:
            The memory tensor after the last fusion layer and layer norm. With no
            fusion layers this is the plain slice of ``hidden_states``, without the norm.
        """
        # Extra encoder layers over the full (memory + token) sequence.
        for self_layer in self.self_layers:
            hidden_states = self_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=None,
                output_attentions=False,
            )[0]

        # The leading positions along axis 1 are the memory slots.
        memory_states = hidden_states[:, : self.memory_len]

        # The same ``final_layer_norm`` is applied after every fusion layer.
        for fusion_layer in self.fusion_layers:
            fused = fusion_layer(
                memory_states,
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=False,
            )[0]
            memory_states = self.final_layer_norm(fused)

        return memory_states
