import torch
import torch.nn as nn
import numpy as np
import json

from typing import Dict
from torch import Tensor
from fairseq.models.transformer import TransformerEncoder
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.modules import LayerNorm
from .utils import *

from examples.summarization.modules.adapter import Adapter

class superEncoder(TransformerEncoder):
    """fairseq ``TransformerEncoder`` extended with post-encoder transforms.

    On top of the standard encoder stack this class can:

    * drop the final encoder LayerNorm (``args.wo_encoder_last_layernorm``);
    * re-normalize the encoder output with sentence-level statistics
      (``args.bn_encoder_output``), or project it per source language with the
      components listed in ``args.component_config``;
    * run an ``Adapter`` on the result (``args.adapter_num_layer != 0``),
      optionally after replacing/appending a learned "postfix" position
      (``args.postfix_tuning``) and optionally gating the adapter output with
      the encoder output (``args.fuse_encoder_and_adapter == "gated"``).

    ``forward`` returns ``(EncoderOut, extra)`` where ``extra`` carries pooled
    sentence vectors (see ``forward``).
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        self.adapter = None
        if args.adapter_num_layer != 0:
            embed_dim = embed_tokens.embedding_dim
            self.adapter = Adapter(args, args.adapter_num_layer, self.padding_idx, embed_dim)

        if getattr(args, "wo_encoder_last_layernorm", None):
            self.layer_norm = None  # delete the final layer norm of encoder

        if self.layer_norm is not None and getattr(args, "tune_encoder_layer_norm", None):
            self.layer_norm.requires_grad_(True)

        if getattr(args, "fuse_encoder_and_adapter", None):
            if args.fuse_encoder_and_adapter == "gated":
                # one scalar gate per position, computed from [encoder; adapter]
                self.gate_fc = nn.Linear(2 * args.encoder_embed_dim, 1, bias=True)
            else:
                raise NotImplementedError
        self.doc_state = args.doc_state
        self.bn_encoder_output = getattr(args, "bn_encoder_output", False)
        # Path of the JSON component config; a truthy value enables the
        # per-language projection branch in forward.
        self.proj_encoder_output = getattr(args, "component_config", None)
        self.ln_after_proj = getattr(args, "ln_after_proj", None)
        self.freeze_ln_after_proj = getattr(args, "freeze_ln_after_proj", False)
        self.proj_k = getattr(args, "proj_k", 6)
        self.lang_cls_input = getattr(args, "lang_cls_input", "adapter_output")
        self.postfix_tuning = getattr(args, "postfix_tuning", None)
        if self.postfix_tuning:
            encoder_embed_dim = args.encoder_embed_dim
            if self.postfix_tuning == "v1":
                # w.o. projection, replace the feature of language token
                postfix_matrix = torch.zeros(encoder_embed_dim)
            else:
                # 128-d learned vector, projected to encoder_embed_dim in forward
                postfix_matrix = torch.zeros(128)
                self.postfix_projection = nn.Linear(128, encoder_embed_dim, bias=False)
            self.postfix_matrix = nn.Parameter(postfix_matrix)
            nn.init.normal_(self.postfix_matrix.data, mean=0, std=encoder_embed_dim ** -0.5)

        # Maps a token id (string key from the JSON config) to the attribute
        # name holding that language's component, e.g. "<id>" -> "en_component".
        self.component_id2name = dict()
        if self.proj_encoder_output is not None:
            with open(args.component_config, "r") as config_file:
                com_config = json.load(config_file)
            for tokenId in com_config:
                lg_token = dictionary[int(tokenId)]  # e.g. [en_XX]
                lg = lg_token.split("_")[0][1:]  # "[en_XX]" -> "en"
                filename = com_config[tokenId]
                component = np.load(filename)
                # keep the first proj_k columns of the stored 2-D array
                component = torch.tensor(component[:, :self.proj_k])
                # A plain tensor (not a Parameter/buffer) is kept in
                # self.__dict__ by nn.Module, so it is not in state_dict and is
                # not moved by Module.to(); forward reads it via self.__dict__.
                # NOTE(review): confirm lir handles device placement.
                self.__setattr__("{}_component".format(lg), component)
                self.component_id2name[tokenId] = "{}_component".format(lg)
            if self.ln_after_proj:
                affine = not getattr(args, "proj_ln_wo_affine", False)
                self.layer_norm_after_proj = LayerNorm(args.encoder_embed_dim, elementwise_affine=affine)

    def _masked_sentence_mean(self, x, encoder_padding_mask):
        """Average ``x`` (``[T, B, H]``) over time with padded positions zeroed.

        Padded positions contribute zeros but are still counted in the
        denominator (division by T). The masking is out-of-place, so ``x``
        itself is left untouched.

        Returns:
            ``[B, H]`` tensor.
        """
        sentence_x = x.transpose(0, 1)  # [B, T, H]
        if encoder_padding_mask is not None:
            keep = ~encoder_padding_mask.to(torch.bool).unsqueeze(-1)
            sentence_x = sentence_x * keep
        return sentence_x.mean(dim=1)

    def _project_by_language(self, x, src_tokens, sentence_x):
        """Apply each language's ``lir`` projection to the rows tagged with it.

        The language of a batch row is given by its last source token, matched
        against the keys of ``component_id2name``.

        Args:
            x: hidden states ``[T, B, H]``.
            src_tokens: ``[B, T]`` source tokens.
            sentence_x: pooled sentence states ``[B, H]``.

        Returns:
            ``(projected, subtract)``. ``projected`` has the shape of ``x``.
            ``subtract`` holds, for every row, the value ``lir`` returned with
            that row's own component. Rows matching no component are zero in
            both outputs.

        Raises:
            ValueError: if the component config defined no components.
        """
        key_token = src_tokens[:, -1]
        projected = None
        subtract = None
        for tokenId, component_name in self.component_id2name.items():
            # [1, B, 1]: selects the batch rows tagged with this language
            mask = torch.eq(key_token, int(tokenId)).unsqueeze(0).unsqueeze(-1)
            component = self.__dict__[component_name]
            result, lang_subtract = lir(
                (mask * x).transpose(0, 1),
                sentence_x,
                component,
            )
            result = result.transpose(0, 1) * mask
            # Mask per row too; otherwise every row would end up with the
            # subtract of whichever component happens to be iterated last.
            lang_subtract = lang_subtract * mask.squeeze(0)
            if projected is None:
                projected, subtract = result, lang_subtract
            else:
                projected = projected + result
                subtract = subtract + lang_subtract
        if projected is None:
            raise ValueError("component_config did not define any components")
        return projected, subtract

    def _apply_postfix(self, x, encoder_padding_mask, src_lengths):
        """Insert the learned postfix feature before the adapter.

        ``"v1"``/``"v2"`` overwrite the last time step (the language token
        position) so the length is unchanged. Any other truthy mode appends one
        extra time step; the padding mask is extended by repeating its last
        column, and the lengths are incremented. ``src_lengths`` is not modified
        in place, so the caller's tensor is preserved.

        Returns:
            ``(x, encoder_padding_mask, src_lengths)``.
        """
        postfix_tuning = getattr(self, "postfix_tuning", None)
        if not postfix_tuning:
            return x, encoder_padding_mask, src_lengths
        x_size = x.size()
        if postfix_tuning == "v1":
            postfix_feature = self.postfix_matrix.expand([1, 1, -1]).repeat([1, x.size(1), 1])
        else:
            postfix_feature = self.postfix_projection(self.postfix_matrix)
            postfix_feature = postfix_feature.expand([1, 1, -1]).repeat([1, x.size(1), 1])
        if postfix_tuning in ["v1", "v2"]:
            # v1 and v2 remove the feature of the language token
            x = torch.cat((x[:-1, :, :], postfix_feature))
            assert x_size == x.size()
        else:
            x = torch.cat((x, postfix_feature))
            assert x_size[0] + 1 == x.size(0)
            assert x_size[1:] == x.size()[1:]
            encoder_padding_mask = torch.cat((encoder_padding_mask, encoder_padding_mask[:, -1:]), dim=-1)
            src_lengths = src_lengths + 1
        return x, encoder_padding_mask, src_lengths

    def forward(self, src_tokens, src_lengths, return_all_hiddens=False, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`. Not modified in place.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple ``(encoder_out, extra)``:
                - **encoder_out** (EncoderOut namedtuple):
                    - **encoder_out** (Tensor): the adapter output (optionally
                      gated with the encoder output), or the last encoder
                      layer's output when there is no adapter; shape
                      `(src_len, batch, embed_dim)`
                    - **encoder_padding_mask** (ByteTensor): the positions of
                      padding elements of shape `(batch, src_len)`
                    - **encoder_embedding** (Tensor): the (scaled) embedding
                      lookup of shape `(batch, src_len, embed_dim)`
                    - **encoder_states** (List[Tensor]): all intermediate
                      hidden states of shape `(src_len, batch, embed_dim)`.
                      Only populated if *return_all_hiddens* is True.
                  When postfix tuning appends a position (modes other than
                  "v1"/"v2"), src_len, the padding mask and src_lengths all
                  include that extra position.
                - **extra** (dict): pooled ``[batch, embed_dim]`` vectors from
                  ``select_doc_state``. It always has "encoder_doc_state" and
                  "doc_state". It has "adapter_doc_state" when an adapter is
                  used, "subtract" and "proj_doc_state" when the projection
                  branch ran, and "lang_cls_input" per ``args.lang_cls_input``.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        encoder_last_hidden = x
        adapter_last_hidden = None
        hidden_after_proj = None
        subtract = None
        if self.bn_encoder_output:
            sentence_x = self._masked_sentence_mean(x, encoder_padding_mask)  # [B, H]
            mean, var = getMeanAndVar(sentence_x)
            normalized_x = normalizedFn(
                x.transpose(0, 1),
                mean.unsqueeze(1), var.unsqueeze(1)
            )
            x = normalized_x.transpose(0, 1)  # [B, L, H] -> [L, B, H]
        elif self.proj_encoder_output:
            sentence_x = self._masked_sentence_mean(x, encoder_padding_mask)  # [B, H]
            x, subtract = self._project_by_language(x, src_tokens, sentence_x)
            if self.ln_after_proj:
                x = self.layer_norm_after_proj(x)
            hidden_after_proj = x

        # Input to the adapter (before any postfix position is inserted).
        # It equals hidden_after_proj when the projection branch ran.
        adapter_input = x
        if getattr(self, "adapter", None) is not None:
            x, encoder_padding_mask, src_lengths = self._apply_postfix(
                x, encoder_padding_mask, src_lengths
            )
            adapter_last_hidden = self.adapter(src_tokens, x, encoder_padding_mask=encoder_padding_mask)
            if getattr(self, "gate_fc", None) is not None:
                # NOTE(review): needs encoder/adapter outputs of equal length,
                # which does not hold when a postfix position is appended.
                out_tensor = self.fuse_encoder_and_adapter(encoder_last_hidden, adapter_last_hidden)
            else:
                out_tensor = adapter_last_hidden
        else:
            out_tensor = encoder_last_hidden
        encoder_doc_state = self.select_doc_state(src_tokens, encoder_last_hidden)
        doc_state = self.select_doc_state(src_tokens, out_tensor, encoder_padding_mask=encoder_padding_mask)

        extra = {"encoder_doc_state": encoder_doc_state, "doc_state": doc_state}

        if adapter_last_hidden is not None:
            adapter_doc_state = self.select_doc_state(src_tokens, adapter_last_hidden, encoder_padding_mask=encoder_padding_mask)
            extra["adapter_doc_state"] = adapter_doc_state
            if self.lang_cls_input == "adapter_output":
                extra['lang_cls_input'] = adapter_doc_state

        if self.lang_cls_input == "adapter_input":
            extra['lang_cls_input'] = self.select_doc_state(src_tokens, adapter_input)

        if subtract is not None:
            extra["subtract"] = subtract

        if hidden_after_proj is not None:
            proj_doc_state = self.select_doc_state(src_tokens, hidden_after_proj)
            extra['proj_doc_state'] = proj_doc_state
        encoder_out = EncoderOut(
            encoder_out=out_tensor,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,
            src_lengths=src_lengths,
        )
        return encoder_out, extra

    def select_doc_state(self, src_tokens, token_states, encoder_padding_mask=None):
        """Pool token states into one vector per sentence.

        Masked positions are replaced by zeros and the mean is then taken over
        all T positions, so the divisor is T rather than the number of
        unmasked tokens.

        Input:
            src_tokens: [B, T]
            token_states: [T, B, C]
            encoder_padding_mask: optional [B, T'] bool mask (True = padding).
                When omitted, padding is derived from ``src_tokens``.
        Return:
            doc_states: [B, C]
        """
        token_states = token_states.transpose(0, 1)  # [T x B x C] -> [B x T x C]
        if encoder_padding_mask is None:
            token_mask = ~torch.eq(src_tokens, self.dictionary.pad()).unsqueeze(-1)  # 0 means masked
        else:
            token_mask = ~encoder_padding_mask.unsqueeze(-1)

        # TODO: remove this patch
        # When token_states is longer than the mask (e.g. an appended postfix
        # position), the extra trailing positions are treated as valid.
        if token_mask.size(1) != token_states.size(1):
            padding_length = token_states.size(1) - token_mask.size(1)
            padding = torch.ones(
                (token_states.size(0), padding_length, 1),
                dtype=token_mask.dtype,
                device=token_states.device,
            )
            token_mask = torch.cat((token_mask, padding), dim=1)

        selected_token_state = torch.where(
            token_mask, token_states, torch.tensor(0).to(token_states)
        )
        doc_state = torch.mean(
            selected_token_state, dim=-2
        )
        return doc_state

    def posttuned_fn(self, args):
        """Re-enable gradients for the modules this encoder adds.

        Covers the fusion gate, the adapter (unless ``args.freeze_adapter``)
        and the post-projection LayerNorm (unless ``freeze_ln_after_proj``).
        Presumably called after the base model has been frozen; confirm with
        the caller.
        """
        if getattr(self, "gate_fc", None) is not None:
            self.gate_fc.requires_grad_(True)
        if getattr(self, "adapter", None) is not None and not getattr(args, "freeze_adapter", False):
            self.adapter.requires_grad_(True)
        if getattr(self, "layer_norm_after_proj", None) is not None and \
            not getattr(self, "freeze_ln_after_proj", False):
            self.layer_norm_after_proj.requires_grad_(True)

    def fuse_encoder_and_adapter(self, encoder_hidden, adapter_hidden):
        """Gate per position: ``w * encoder + (1 - w) * adapter``, ``w = sigmoid(gate_fc([enc; ada]))``."""
        fused_logit = self.gate_fc(torch.cat([encoder_hidden, adapter_hidden], dim=-1))  # [T, B, 1]
        fused_weight = torch.sigmoid(fused_logit)
        fused_hidden = fused_weight * encoder_hidden + (1 - fused_weight) * adapter_hidden
        return fused_hidden

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
        """Reorder for beam search.

        Accepts either an ``EncoderOut`` or the ``(EncoderOut, extra)`` pair
        returned by ``forward``. Only the reordered ``EncoderOut`` is
        returned; ``extra`` is dropped.
        """
        if not isinstance(encoder_out, EncoderOut):
            encoder_out_tuple = encoder_out[0]
        else:
            encoder_out_tuple = encoder_out
        assert isinstance(encoder_out_tuple, EncoderOut)
        return super().reorder_encoder_out(encoder_out_tuple, new_order)

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            [encoder_out, extra] = self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
            return [encoder_out, extra]
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        """Call ``forward`` with every ``net_input`` entry except ``prev_output_tokens``."""
        encoder_input = {
            k: v
            for k, v in net_input.items()
            if k != "prev_output_tokens"
        }
        encoder_out, extra = self.forward(**encoder_input)
        return [encoder_out, extra]

class superFlEncoder(superEncoder):
    """``superEncoder`` variant that fuses language-specific and -agnostic views.

    The language-specific view is the ``subtract`` vector produced by the
    per-language projection branch (``args.component_config``), repeated over
    time. The language-agnostic view is the adapter output. ``args.fl_method``
    selects the fusion (see ``fuse_lang_specific_agnostic_fn``).
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        self.fusing_method = args.fl_method
        if self.fusing_method == "gated":
            # gate computed from both views
            self.fl_gate = nn.Linear(2 * args.encoder_embed_dim, 1, bias=True)
        elif self.fusing_method == "gated_w_lang_agnostic":
            # gate computed from the language-agnostic view only
            self.fl_gate = nn.Linear(args.encoder_embed_dim, 1, bias=True)
        elif self.fusing_method == "concat":
            self.concat_proj = nn.Linear(
                2 * args.encoder_embed_dim, args.encoder_embed_dim, bias=True
            )
        # "sum" needs no parameters; unknown methods are rejected when fusing.

    def fuse_lang_specific_agnostic_fn(
        self, lang_specific_hidden, lang_agnostic_hidden
    ):
        """Fuse the two views according to ``self.fusing_method``.

        Input:
            lang_specific_hidden: [T, B, H]
            lang_agnostic_hidden: [T, B, H]
        Return:
            fused_hidden: [T, B, H]
        Raises:
            NotImplementedError: for an unknown fusing method. Previously
                ``None`` was returned silently.
        """
        method = self.fusing_method
        if method == "sum":
            return lang_specific_hidden + lang_agnostic_hidden
        if method == "concat":
            return self.concat_proj(torch.cat([lang_specific_hidden, lang_agnostic_hidden], dim=-1))
        if method == "gated":
            fused_logit = self.fl_gate(torch.cat([lang_specific_hidden, lang_agnostic_hidden], dim=-1))  # [T, B, 1]
        elif method == "gated_w_lang_agnostic":
            fused_logit = self.fl_gate(lang_agnostic_hidden)  # [T, B, 1]
        else:
            raise NotImplementedError("unknown fl_method: {!r}".format(method))
        fused_weight = torch.sigmoid(fused_logit)
        return fused_weight * lang_specific_hidden + (1 - fused_weight) * lang_agnostic_hidden

    def forward(self, src_tokens, src_lengths, return_all_hiddens=False, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple ``(encoder_out, extra)``:
                - **encoder_out** (EncoderOut namedtuple):
                    - **encoder_out** (Tensor): fused language-specific /
                      adapter output, or the last encoder layer's output when
                      there is no adapter; shape `(src_len, batch, embed_dim)`
                    - **encoder_padding_mask** (ByteTensor): the positions of
                      padding elements of shape `(batch, src_len)`
                    - **encoder_embedding** (Tensor): the (scaled) embedding
                      lookup of shape `(batch, src_len, embed_dim)`
                    - **encoder_states** (List[Tensor]): all intermediate
                      hidden states of shape `(src_len, batch, embed_dim)`.
                      Only populated if *return_all_hiddens* is True.
                - **extra** (dict): always has "encoder_doc_state". It has
                  "subtract" and "proj_doc_state" when the projection branch
                  ran, and "adapter_doc_state" when an adapter is used.

        Raises:
            ValueError: if an adapter is configured but the projection branch
                did not produce ``subtract``, or if the component config
                defined no components.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        encoder_last_hidden = x
        adapter_last_hidden = None
        hidden_after_proj = None
        subtract = None
        if self.bn_encoder_output or self.proj_encoder_output:
            # Sentence mean over time with padding zeroed (padding is still
            # counted in the denominator). Masking is out-of-place: an
            # in-place multiply on this transpose view would also zero the
            # padding positions of x / encoder_last_hidden.
            sentence_x = x.transpose(0, 1)  # [B, L, H]
            if encoder_padding_mask is not None:
                keep = ~encoder_padding_mask.to(torch.bool).unsqueeze(-1)
                sentence_x = sentence_x * keep
            sentence_x = sentence_x.mean(dim=1)  # [B, H]

        if self.bn_encoder_output:
            mean, var = getMeanAndVar(sentence_x)
            normalized_x = normalizedFn(
                x.transpose(0, 1),
                mean.unsqueeze(1), var.unsqueeze(1)
            )
            x = normalized_x.transpose(0, 1)  # [B, L, H] -> [L, B, H]
        elif self.proj_encoder_output:
            # the last source token of each row selects its language component
            key_token = src_tokens[:, -1]
            normalized_x = None
            for tokenId, component_name in self.component_id2name.items():
                mask = torch.eq(
                    key_token, int(tokenId)
                ).unsqueeze(0).unsqueeze(-1)  # [1, B, 1]
                component = self.__dict__[component_name]
                # lang_subtract: [B, H]
                result, lang_subtract = lir(
                    (mask * x).transpose(0, 1),
                    sentence_x,
                    component
                )
                result = result.transpose(0, 1) * mask
                # Mask per row so each row keeps the subtract of its own
                # language, not that of the last component iterated.
                lang_subtract = lang_subtract * mask.squeeze(0)
                if normalized_x is None:
                    normalized_x, subtract = result, lang_subtract
                else:
                    normalized_x = normalized_x + result
                    subtract = subtract + lang_subtract
            if normalized_x is None:
                raise ValueError("component_config did not define any components")
            x = normalized_x

            if self.ln_after_proj:
                x = self.layer_norm_after_proj(x)
            hidden_after_proj = x

        if getattr(self, "adapter", None) is not None:
            if subtract is None:
                raise ValueError(
                    "superFlEncoder requires the language-specific `subtract` vector: "
                    "set component_config and do not enable bn_encoder_output"
                )
            adapter_last_hidden = self.adapter(src_tokens, x)
            extended_subtract = subtract.unsqueeze(0).repeat([adapter_last_hidden.size(0), 1, 1])  # [T, B, C]
            out_tensor = self.fuse_lang_specific_agnostic_fn(
                lang_specific_hidden=extended_subtract,
                lang_agnostic_hidden=adapter_last_hidden
            )
        else:
            out_tensor = encoder_last_hidden
        encoder_doc_state = self.select_doc_state(src_tokens, encoder_last_hidden)
        extra = {"encoder_doc_state": encoder_doc_state}

        if subtract is not None:
            extra["subtract"] = subtract

        if hidden_after_proj is not None:
            proj_doc_state = self.select_doc_state(src_tokens, hidden_after_proj)
            extra['proj_doc_state'] = proj_doc_state

        if adapter_last_hidden is not None:
            adapter_doc_state = self.select_doc_state(src_tokens, adapter_last_hidden)
            extra["adapter_doc_state"] = adapter_doc_state

        encoder_out = EncoderOut(
            encoder_out=out_tensor,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,
            src_lengths=src_lengths,
        )
        return encoder_out, extra

