import torch
import json
import torch.nn as nn

from examples.summarization.modules.hubAdapter import Adapter, GatedAdapter
from typing import Dict
from torch import Tensor
from fairseq.models.fairseq_encoder import EncoderOut


class intergratedAdapter(nn.Module):
    """Stack of adapter layers applied on top of an encoder (or decoder) output.

    The architecture is read from a JSON file whose ``"layers"`` entry is a
    list of layer configs. Each layer config contains:

    * ``"adapter_type"``: one of ``"pfeiffer"``, ``"final_ln"``, ``"pre_ln"``
      or ``"gated_final_ln"`` (``"TransformerEncoderLayer"`` is reserved but
      not implemented).
    * ``"adapter_name"``: list of names; one adapter of the given type is built
      per name and registered as the submodule ``"<name>_adapter"``.
    * ``"routing"`` (optional): mapping ``token id (JSON string key) -> adapter
      name``. Without it, only the first adapter in ``"adapter_name"`` is
      applied. With it, every routed adapter is run on the whole batch and its
      output is kept only for samples whose last source token equals the
      routing token id; samples matching no entry get an all-zero output.
    """

    def __init__(
        self, args, double_input=False, iadapter_config=None,
    ):
        """
        Args:
            args: model args; ``args.encoder_embed_dim`` sets the adapter
                sizes and ``args.iadapter_config`` is the fallback config path.
            double_input: if True, adapters receive ``cat([x, input_b], -1)``
                (twice the embedding dim) instead of ``x``.
            iadapter_config: path to the JSON config; defaults to
                ``args.iadapter_config``.
        """
        super().__init__()
        if iadapter_config is None:
            iadapter_config = args.iadapter_config
        # Context manager so the config file handle is closed promptly.
        with open(iadapter_config, "r", encoding="utf-8") as config_file:
            self.config = json.load(config_file)
        self.layer_configs = self.config["layers"]
        # Plain dict for name lookup. The same adapters are also registered as
        # "<name>_adapter" submodules in build_layer, so nn.Module tracks their
        # parameters and checkpoint keys keep that naming.
        self.adapters = {}
        self.double_input = double_input
        for layer_config in self.layer_configs:
            self.build_layer(args, layer_config)

    def build_layer(self, args, layer_config):
        """Build and register one adapter per name in ``layer_config["adapter_name"]``.

        Raises:
            NotImplementedError: if ``layer_config["adapter_type"]`` is not a
                supported adapter type.
        """
        adapter_type = layer_config["adapter_type"]
        output_dim = args.encoder_embed_dim
        input_dim = output_dim * 2 if self.double_input else output_dim
        # Extra keyword arguments for the Adapter-based variants. Only the
        # flags listed here are passed; anything else uses Adapter's defaults.
        adapter_kwargs_by_type = {
            "pfeiffer": dict(
                add_layer_norm_before=False,
                add_layer_norm_after=False,
                residual_before_ln=True,
            ),
            "final_ln": dict(
                add_layer_norm_before=False,
                add_layer_norm_after=False,
                residual_before_ln=False,
                final_ln=True,
            ),
            "pre_ln": dict(
                add_layer_norm_before=True,
                add_layer_norm_after=False,
                residual_before_ln=False,
            ),
        }
        # Validate the type up front so a bad config fails even when the
        # adapter_name list is empty.
        if adapter_type == "TransformerEncoderLayer":
            raise NotImplementedError(
                "adapter_type 'TransformerEncoderLayer' is not implemented"
            )
        if adapter_type != "gated_final_ln" and adapter_type not in adapter_kwargs_by_type:
            raise NotImplementedError(
                "Unsupported adapter_type: {!r}".format(adapter_type)
            )
        for name in layer_config["adapter_name"]:
            if adapter_type == "gated_final_ln":
                adapter = GatedAdapter(
                    input_dim,
                    output_size=output_dim,
                )
            else:
                adapter = Adapter(
                    input_dim,
                    output_size=output_dim,
                    down_sample=input_dim // 2,
                    **adapter_kwargs_by_type[adapter_type]
                )
            # nn.Module.__setattr__ registers the adapter as a submodule.
            setattr(self, "{}_adapter".format(name), adapter)
            self.adapters[name] = adapter

    def _adapter_input(self, x, input_b):
        """Return the tensor fed to an adapter: ``x``, or ``cat([x, input_b], -1)``
        when ``self.double_input`` is True."""
        if not self.double_input:
            return x
        assert input_b is not None, "Please provide input_b if self.double_input is True"
        return torch.cat([x, input_b], dim=-1)

    def layer_forward(
        self, src_tokens, x, layer_config, encoder_padding_mask,
        input_b=None
    ):
        """Apply one adapter layer described by ``layer_config``.

        Input:
            src_tokens: [B, T]
            x: [T, B, C]
            layer_config: one entry of ``self.layer_configs``
            encoder_padding_mask: not used by this method; kept for interface
                compatibility
            input_b: same shape as x; required when self.double_input is True,
                ignored otherwise
        Return:
            [T, B, C]

        Raises:
            ValueError: if ``layer_config["routing"]`` is present but empty.
        """
        routing = layer_config.get("routing", None)
        if routing is not None and not routing:
            raise ValueError("layer_config['routing'] must not be empty")

        # The adapter input and the residual (always x) are the same for every
        # adapter in this layer, so they are computed once.
        adapter_input = self._adapter_input(x, input_b)

        if routing is None:
            adapter = self.adapters[layer_config["adapter_name"][0]]
            result, _, _ = adapter(adapter_input, x)
            return result

        # Each sample is routed by its last source token.
        key_token = src_tokens[:, -1]
        result = None
        for token_id, adapter_name in routing.items():
            # JSON object keys are strings; the mask is [1, B, 1] so it
            # broadcasts over the time and channel dims of [T, B, C].
            mask = torch.eq(key_token, int(token_id)).unsqueeze(0).unsqueeze(-1)
            adapter_result, _, _ = self.adapters[adapter_name](adapter_input, x)
            adapter_result = adapter_result * mask
            result = adapter_result if result is None else result + adapter_result
        return result

    def forward(self, src_tokens, encoder_out_tuple, encoder_out_tensor, input_b=None):
        """Run every adapter layer on an encoder output.

        Input:
            src_tokens: Tensor
            encoder_out_tuple: EncoderOut
            encoder_out_tensor: Tensor
            input_b: optional second input, required when self.double_input
        Return:
            adapter_out: EncoderOut whose ``encoder_out`` is the adapted tensor;
            all other fields are copied from ``encoder_out_tuple``.
        """
        x = encoder_out_tensor
        encoder_padding_mask = encoder_out_tuple.encoder_padding_mask
        for layer_config in self.layer_configs:
            x = self.layer_forward(
                src_tokens, x, layer_config,
                encoder_padding_mask,
                input_b=input_b
            )
        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_out_tuple.encoder_embedding,  # B x T x C
            encoder_states=encoder_out_tuple.encoder_states,  # List[T x B x C]
            src_tokens=encoder_out_tuple.src_tokens,
            src_lengths=encoder_out_tuple.src_lengths,
        )

    def decoder_forward(self, src_tokens, input_tensor, input_b=None):
        """Run every adapter layer on an arbitrary tensor (e.g. decoder states).

        Input:
            src_tokens: Tensor, used for routing
            input_tensor: Tensor
            input_b: optional second input, required when self.double_input
        Return:
            result: Tensor
        """
        x = input_tensor
        for layer_config in self.layer_configs:
            x = self.layer_forward(
                src_tokens, x, layer_config,
                encoder_padding_mask=None,
                input_b=input_b
            )
        return x

    def forward_torchscript(
        self,
        net_input: Dict[str, Tensor],
        encoder_out: EncoderOut,
        extra=None
    ):
        """A TorchScript-compatible version of forward.

        Args:
            net_input: must contain ``"src_tokens"``.
            encoder_out: encoder output to adapt.
            extra: optional dict. When not scripting, ``extra["subtract"]`` (if
                present) is used as ``input_b``. It is expected to be 2-D
                (presumably [B, C] — TODO confirm) and is expanded along a new
                leading time dimension to match ``encoder_out.encoder_out``.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                encoder_out_tuple=encoder_out,
                encoder_out_tensor=encoder_out.encoder_out
            )
        # `extra` defaults to None, so guard before calling .get on it.
        input_b = None if extra is None else extra.get("subtract", None)
        if input_b is not None:
            input_b = input_b.unsqueeze(0).expand(encoder_out.encoder_out.size(0), -1, -1)
        adapter_net_input = {
            "src_tokens": net_input["src_tokens"],
            "encoder_out_tuple": encoder_out,
            "encoder_out_tensor": encoder_out.encoder_out,
            "input_b": input_b
        }
        return self.forward_non_torchscript(adapter_net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        """Eager-mode path: unpack ``net_input`` as keyword arguments to forward.

        NOTE(review): despite the annotation, values include an EncoderOut and
        possibly None; the method is excluded from scripting via jit.unused.
        """
        return self.forward(**net_input)
