import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.lora.bnb import Linear4bit as LoraLinear4bit
from peft.tuners.lora.bnb import Linear8bitLt as LoraLinear8bitLt
from peft.tuners.lora.layer import Linear as LoraLinear
from transformers.modeling_outputs import (
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
)
from transformers.models.m2m_100.modeling_m2m_100 import (
    M2M100Encoder,
    M2M100Model,
)
from transformers.pytorch_utils import Conv1D
from trident.core.module import TridentModule
from trident.utils.logging import get_logger
from typing import cast, Optional
from src.modules.functional import pooling
from transformers.models.llama import LlamaModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from peft.peft_model import PeftModelForFeatureExtraction
from functools import partial

# Module-level logger obtained from trident's logging helper.
log = get_logger(__name__)

# NOTE: the two calls below run at import time and change process-wide torch settings.
# "medium" permits reduced-precision internal math for float32 matmuls (speed over accuracy).
torch.set_float32_matmul_precision("medium")
# Explicitly do not require deterministic algorithm implementations.
torch.use_deterministic_algorithms(False)


def get_leaf_modules(
    module: nn.Module, parent_name: str = "", kind: str = "linear"
) -> dict[str, nn.Module]:
    """Recursively collect selected sub-modules of an input `nn.Module`.

    Args:
        module: Root of the (sub-)tree to walk.
        parent_name: Dotted name of `module` relative to the original root;
            used as the key prefix for everything found below it.
        kind: Which modules to collect.
            ``"linear"``: modules *without* children that are `nn.Linear` or
            `Conv1D`.
            ``"lora"``: modules *with* children that are one of the PEFT LoRA
            linear wrappers.
            Any other value collects nothing.

    Returns:
        Mapping from dotted module name to module. Descendants are inserted
        before the module that contains them.
    """
    found: dict[str, nn.Module] = {}

    children = list(module.named_children())
    for child_name, child in children:
        qualified = child_name if not parent_name else f"{parent_name}.{child_name}"
        found.update(get_leaf_modules(child, qualified, kind))

    is_leaf = not children
    if kind == "linear":
        # Conv1D is what GPT2 uses in place of nn.Linear
        matches = is_leaf and isinstance(module, (nn.Linear, Conv1D))
    elif kind == "lora":
        matches = (not is_leaf) and isinstance(
            module, (LoraLinear, LoraLinear4bit, LoraLinear8bitLt)
        )
    else:
        matches = False

    if matches:
        found[parent_name] = module
    return found


def FVU(x: torch.Tensor, x_hat: torch.Tensor, mse_loss: torch.Tensor):
    """Fraction of Variance Unexplained: ``mse_loss / Var(x)``.

    Args:
        x: Reference ("original") activations of shape ``(..., d_model)``.
            All leading dimensions are flattened into one sample axis.
        x_hat: Approximation of `x`. Kept for interface compatibility only;
            the reconstruction error is supplied pre-computed via `mse_loss`.
        mse_loss: Mean squared error between `x_hat` and `x`.

    Returns:
        `mse_loss` divided by the variance of `x`, where the variance is the
        mean squared deviation from the per-feature mean taken over the
        flattened sample axis. If `x` has zero variance (e.g. a single sample)
        the result is inf/nan, since no epsilon is applied.
    """
    d_model = x.shape[-1]
    # reshape (not view): activations may be non-contiguous, e.g. after a
    # transpose, in which case view() raises.
    x = x.reshape(-1, d_model)

    # compute variance of the original activations
    variance = (x - x.mean(dim=0)).pow(2).mean()

    # return ratio of the MSE to the variance of the original activations
    return mse_loss / variance


class NLLBEncoder(nn.Module):
    """Wrap an NLLB (M2M100) encoder and add a pooled sequence embedding.

    `nllb` may be a full `M2M100Model`, in which case only its `.encoder` is
    kept, or anything else, which is stored as the encoder unchanged.
    The pooling function is looked up by name in `src.modules.functional.pooling`
    and is always called with the configured `padding_side`.
    """

    def __init__(
        self,
        nllb: M2M100Model,
        pooling_strategy: str = "mean",
        padding_side: str = "right",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        if isinstance(nllb, M2M100Model):
            encoder = nllb.encoder
        else:
            encoder = nllb
        self.nllb: M2M100Encoder = encoder
        self.pooling_strategy = pooling_strategy
        pool = getattr(pooling, self.pooling_strategy)
        self.pooling_fn = partial(pool, padding_side=padding_side)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Encode the inputs and return token-level states plus the pooled output."""
        encoded = self.nllb(input_ids=input_ids, attention_mask=attention_mask)
        token_embeds_NLD = encoded.last_hidden_state
        pooled = self.pooling_fn(token_embeds_NLD, attention_mask)
        return BaseModelOutputWithPooling(
            last_hidden_state=token_embeds_NLD,
            pooler_output=pooled,
        )


class NLLBLlamaEncoder(nn.Module):
    """Feed up-projected NLLB encoder states into Llama and pool the result.

    Pipeline: ``nllb.encoder`` -> ``up_proj`` (NLLB hidden size to Llama hidden
    size) -> ``llama(inputs_embeds=...)`` -> pooling over Llama's last hidden
    state. All NLLB encoder parameters are frozen in ``__init__``.
    """

    def __init__(
        self,
        llama: LlamaModel | PeftModelForFeatureExtraction,
        nllb: M2M100Model,
        pooling_strategy: str = "mean",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.llama = llama
        self.nllb: M2M100Encoder = nllb.encoder
        self.up_proj = nn.Linear(nllb.config.hidden_size, llama.config.hidden_size)
        self.pooling_strategy = pooling_strategy
        self.pooling_fn = getattr(pooling, self.pooling_strategy)

        # the NLLB encoder is never trained
        for p in self.nllb.parameters():
            p.requires_grad = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Return Llama's last hidden state and its pooled sequence embedding."""
        # no_grad rather than inference_mode: the result leaves this block and
        # is consumed by `self.llama`. Tensors created under inference_mode are
        # "inference tensors", which autograd refuses to save for backward, so
        # they break training as soon as an op needs them for a gradient.
        # Gradient semantics are unchanged: nothing inside the block is tracked.
        # NOTE(review): `up_proj` therefore receives no gradient through this
        # forward even though its parameters are not frozen; confirm intended.
        with torch.no_grad():
            nllb_embeds_MLD = self.nllb(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
            nllb_embeds_MLD = self.up_proj(nllb_embeds_MLD)
        outputs = self.llama(
            inputs_embeds=nllb_embeds_MLD,
            attention_mask=attention_mask,
        )
        pooler_output = self.pooling_fn(outputs.last_hidden_state, attention_mask)
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state, pooler_output=pooler_output
        )


class DistillationModule(TridentModule):
    """Distil Llama (adapters disabled) into NLLB -> up_proj -> Llama (adapters enabled).

    The teacher is `self.model.llama` with its adapter layers disabled, fed
    with Llama token ids. The student is the same Llama with adapters enabled,
    fed with up-projected NLLB encoder states.

    `distillation_strategy`:
        ``"last"``: MSE between pooled last hidden states (default).
        ``"all"``: MSE on the outputs of hooked layers, see `_initialize_model`.
    """

    def __init__(
        self,
        pooling_strategy: str = "mean",
        # one of "last", "all"
        distillation_strategy: str = "last",
        *args,
        **kwargs,
    ) -> None:
        # logs all configs to self.hyperparams
        super().__init__(*args, **kwargs)
        self.pooling_strategy = pooling_strategy
        # Look up with a default so an unknown strategy reaches the explicit
        # check below instead of surfacing as a bare AttributeError.
        pooling_fn = getattr(pooling, self.pooling_strategy, None)
        if pooling_fn is None:
            raise ValueError("`self.pooling_strategy` must be one of mean, eos, cls")
        self.pooling_fn = pooling_fn
        self.distillation_strategy = distillation_strategy
        self._initialize_model()

    def _initialize_model(
        self,
    ):
        """Register forward hooks for layer-wise distillation (strategy "all").

        For any other strategy all hook-related attributes are set to None.
        """
        # TODO: this probably doesn't work correctly after refactor any longer
        if self.distillation_strategy == "all":
            peft_modules = get_leaf_modules(self.model.llama, kind="lora")
            self.base_modules = get_leaf_modules(self.model.llama, kind="linear")
            self.base_out_dict = self._setup_hooks(self.base_modules)
            # drop the first len("base_model.model.") characters of every key
            self.peft_modules = {
                k[len("base_model.model.") :]: v for k, v in peft_modules.items()
            }
            self.peft_out_dict = self._setup_hooks(self.peft_modules)
        else:
            self.base_modules = None
            self.base_out_dict = None
            self.peft_modules = None
            self.peft_out_dict = None

    def _setup_hooks(self, modules: dict[str, nn.Module]):
        """Attach a forward hook to every module that records its latest output.

        Returns the dict the hooks write into, keyed like `modules`. Entries
        are overwritten on every forward pass of the hooked module.
        """

        def set_hook(
            out_dict: dict[str, torch.Tensor],
            layer_name: str,
        ):
            def hook(_, __, output):
                out_dict[layer_name] = output

            return hook

        out_dict: dict[str, torch.Tensor] = {}
        for name, module in modules.items():
            module.register_forward_hook(set_hook(out_dict, name))
        return out_dict

    def forward(self, batch: dict):
        """Run the student path (NLLB -> up_proj -> Llama) without tracking
        gradients for the NLLB encoder and the up-projection."""
        # use constructed model during validation
        # no_grad instead of inference_mode so the embeddings handed to Llama
        # are regular tensors (inference tensors cannot be saved for backward).
        with torch.no_grad():
            nllb_embeds_MLD = self.model.nllb(
                input_ids=batch["nllb_input_ids"],
                attention_mask=batch["nllb_attention_mask"],
            ).last_hidden_state
            nllb_embeds_MLD = self.model.up_proj(nllb_embeds_MLD)
        return self.model.llama(
            inputs_embeds=nllb_embeds_MLD,
            attention_mask=batch.get("nllb_attention_mask"),
            labels=batch.get("labels"),
        )

    def training_step(  # type: ignore
        self, batch: dict[str, torch.Tensor], batch_idx: int = 0
    ) -> torch.Tensor:
        """Compute the distillation loss for one batch.

        Reads `input_ids` / `attention_mask` (teacher inputs) and
        `nllb_input_ids` / `nllb_attention_mask` (student inputs) from `batch`;
        the layer-wise branch additionally reads `eos_token_idx` and
        `nllb_eos_token_idx`. Logs MSE and FVU and returns the MSE loss
        (averaged over hooked layers in the layer-wise branch).
        """
        # original model input
        self.model.llama.disable_adapter_layers()
        # Teacher and frozen NLLB encoder need no gradients. This must be
        # no_grad, not inference_mode: the NLLB states are fed to the trainable
        # `up_proj` below, and autograd cannot save inference tensors for the
        # backward pass ("Inference tensors cannot be saved for backward").
        with torch.no_grad():
            # through Llama 3 w/o LoRA
            # potentially task fine-tuned
            # base_outputs.last_hidden_state is (N, L, D)
            llama_outputs = self.model.llama(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )

            # self.model.nllb is encoder of NLLB
            nllb_embeds_NKd = self.model.nllb(
                input_ids=batch["nllb_input_ids"],
                attention_mask=batch["nllb_attention_mask"],
            ).last_hidden_state  # (M, L, d)
        # up-projection (trainable, hence outside the no_grad block)
        # nllb_embeds_NKd ->  nllb_embeds_NKD
        nllb_embeds_NKD = self.model.up_proj(nllb_embeds_NKd)

        # snapshot the teacher's hooked outputs: the student forward below
        # triggers the same hooks and would overwrite them
        if self.base_out_dict is not None:
            base_out_dict = dict(self.base_out_dict)
        else:
            base_out_dict = None

        self.model.llama.enable_adapter_layers()
        nllb_llama_outputs = self.model.llama(
            inputs_embeds=nllb_embeds_NKD, attention_mask=batch["nllb_attention_mask"]
        )

        # simple last layer distillation
        if base_out_dict is None:
            nllb_hidden_states = self.pooling_fn(
                nllb_llama_outputs.last_hidden_state,
                attention_mask=batch["nllb_attention_mask"],
            )
            # CLS, EOS, Mean pooling
            llama_hidden_states = self.pooling_fn(
                llama_outputs.last_hidden_state, attention_mask=batch["attention_mask"]
            )
            mse_loss = F.mse_loss(nllb_hidden_states, llama_hidden_states)
            with torch.no_grad():
                # FVU normalises by the variance of the reference activations,
                # i.e. the teacher (`x`); the student is the approximation.
                fvu_loss = FVU(
                    x=llama_hidden_states, x_hat=nllb_hidden_states, mse_loss=mse_loss
                )
            self.log("train/mse", mse_loss)
            self.log("train/fvu", fvu_loss)
            return mse_loss
        # layer-wise distillation
        else:
            losses = []
            device = batch["input_ids"].device
            # row indices used to pick each sequence's EOS position
            N = torch.arange(batch["input_ids"].shape[0], device=device)
            M = torch.arange(batch["nllb_input_ids"].shape[0], device=device)
            for name, tensor in base_out_dict.items():
                peft_tensor = cast(dict, self.peft_out_dict)[name]
                mse_loss = F.mse_loss(
                    peft_tensor[M, batch["nllb_eos_token_idx"]],
                    tensor[N, batch["eos_token_idx"]],
                )
                with torch.no_grad():
                    # teacher output is the reference, student the approximation
                    fvu_loss = FVU(x=tensor, x_hat=peft_tensor, mse_loss=mse_loss)
                self.log(f"train/{name}/mse_loss", mse_loss)
                self.log(f"train/{name}/fvu_loss", fvu_loss)
                losses.append(mse_loss)
            loss = torch.stack(losses).mean()
            return loss


class AutoModule(TridentModule):
    """Common base for the task modules below.

    Sets up:
      * `is_nllb`: True when "NLLB" occurs in the model's type string.
      * `batch_prefix`: ``"nllb_"`` for NLLB models (unless `needs_prefix` is
        False), otherwise ``""``; subclasses prepend it to batch keys.
      * `pooling_fn`: pooling function looked up by name, bound to `padding_side`.
      * optional restore of a full-module checkpoint (`nllb_ckpt`, NLLB models only).
      * optional checkpoint dump after every validation run
        (`save_checkpoint_on_validation_dir`).
    """

    def __init__(
        self,
        # simplifies checkpointing to align with custom validation
        nllb_ckpt: None | str = None,
        checkpoint_path: Optional[str] = None,
        save_checkpoint_on_validation_dir: Optional[str] = None,
        pooling_strategy: str = "mean",
        padding_side: str = "right",
        needs_prefix: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.checkpoint_path = checkpoint_path
        self._loaded = False
        self.is_nllb = "NLLB" in str(type(self.model))
        if self.is_nllb:
            self.batch_prefix = "nllb_"
            if isinstance(nllb_ckpt, str):
                # Keep loading onto cuda:0 where available, but fall back to
                # CPU so the restore also works on hosts without a GPU;
                # load_state_dict copies into the existing parameters anyway.
                map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
                ckpt = torch.load(nllb_ckpt, map_location=map_location)["state_dict"]
                self.load_state_dict(ckpt, strict=True)
                log.info(f"Successfully restored {nllb_ckpt.split('/')[-1]} checkpoint")
            else:
                log.info("No checkpoint restored")
        else:
            self.batch_prefix = ""

        if not needs_prefix:
            self.batch_prefix = ""

        self.save_checkpoint_on_validation_dir = save_checkpoint_on_validation_dir
        if self.save_checkpoint_on_validation_dir is not None:
            # counter used in the checkpoint file name
            self._validation_epoch = 0

        self.pooling_strategy = pooling_strategy
        self.pooling_fn = partial(
            getattr(pooling, self.pooling_strategy), padding_side=padding_side
        )

    def on_validation_end(self) -> None:
        """Save a weights-only checkpoint after validation when a directory is configured."""
        super().on_validation_end()
        if self.save_checkpoint_on_validation_dir:
            from pathlib import Path

            path = Path(self.save_checkpoint_on_validation_dir).joinpath(
                f"validation-epoch={self._validation_epoch}.ckpt"
            )
            self.trainer.save_checkpoint(path, weights_only=True)
            self._validation_epoch += 1

    # Disabled: lazy restore from `checkpoint_path`; `checkpoint_path` and
    # `_loaded` are only referenced here.
    # def setup(self, stage):
    #     if not self._loaded and self.checkpoint_path is not None:
    #         ckpt = torch.load(self.checkpoint_path, map_location="cuda:0")["state_dict"]
    #         self.load_state_dict(ckpt, strict=True)
    #         log.info("Successfully loaded checkpoint path")
    #         self._loaded = True


class AutoModuleForSequenceClassification(AutoModule):
    """Sequence classifier: pooled encoder output -> bias-free linear head."""

    def __init__(self, num_labels: int = 3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pick the config that describes the width of the states being pooled.
        if not self.is_nllb:
            hidden_size = self.model.config.hidden_size
        elif hasattr(self.model, "llama"):
            hidden_size = self.model.llama.config.hidden_size
        else:
            hidden_size = self.model.nllb.config.hidden_size
        self.head = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, batch, *args, **kwargs):
        """Return a dict with the cross-entropy `loss`, `logits` and the pooled
        `sequence_embeds`. Inputs are read from `batch` under
        `{batch_prefix}input_ids` / `{batch_prefix}attention_mask`, labels
        under `labels`."""
        ids_key = f"{self.batch_prefix}input_ids"
        mask_key = f"{self.batch_prefix}attention_mask"

        outputs = self.model(
            input_ids=batch[ids_key],
            attention_mask=batch[mask_key],
        )
        sequence_embeds = self.pooling_fn(outputs.last_hidden_state, batch[mask_key])
        logits = self.head(sequence_embeds)

        result = {
            "loss": F.cross_entropy(logits, batch["labels"]),  # type: ignore
            "logits": logits,
            "sequence_embeds": sequence_embeds,
        }
        return result


class AutoModuleForSequenceClassificationDistillation(
    AutoModuleForSequenceClassification
):
    """Sequence classifier whose head is restored from a checkpoint and which
    is trained to match pre-computed sequence embeddings.

    Args:
        ckpt: Path to a checkpoint whose ``"state_dict"`` contains
            ``"head.weight"``; that tensor is loaded into `self.head`.
    """

    def __init__(self, ckpt: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prefer cuda:0 as before, but do not require a GPU just to read the
        # head weights; load_state_dict copies into the existing parameter.
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
        state_dict = torch.load(ckpt, map_location=map_location)["state_dict"]
        ckpt_ = {"weight": state_dict["head.weight"]}
        assert isinstance(self.head, nn.Module)
        self.head.load_state_dict(ckpt_, strict=True)

    def training_step(self, batch: dict, batch_idx: int) -> dict:
        """MSE between the model's pooled embeddings and `batch["sequence_embeds"]`."""
        out = self(batch)
        loss = F.mse_loss(out["sequence_embeds"], batch["sequence_embeds"])
        self.log("train/mse", loss)
        return {"loss": loss}


class AutoModuleForMultipleChoice(AutoModule):
    """Multiple-choice module: pooled sequence embedding -> linear head,
    with logits reshaped to ``(-1, num_choices)``.

    Args:
        num_choices: Number of answer options per example.
    """

    def __init__(self, num_choices: int = 4, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # head as per XLMRobertaForMultipleChoice
        # NOTE(review): NLLB-based models get `num_choices` outputs per
        # sequence while other models get a single score per sequence.
        # Confirm this asymmetry matches how batches are laid out.
        if self.is_nllb:
            if hasattr(self.model, "llama"):
                self.head = nn.Linear(
                    self.model.llama.config.hidden_size, num_choices, bias=False
                )
            else:
                self.head = nn.Linear(
                    self.model.nllb.config.hidden_size, num_choices, bias=False
                )
        else:
            self.head = nn.Linear(self.model.config.hidden_size, 1, bias=False)
        # Honour the constructor argument (was hard-coded to 4, so any other
        # value left the head size and the reshape in `forward` out of sync).
        self.num_choices = num_choices

    def forward(self, batch, *args, **kwargs):
        """Return a dict with the cross-entropy `loss` and `logits` of shape
        ``(-1, num_choices)``. Inputs are read from `batch` under
        `{batch_prefix}input_ids` / `{batch_prefix}attention_mask`, labels
        under `labels`."""
        input_ids = batch[f"{self.batch_prefix}input_ids"]
        attention_mask = batch[f"{self.batch_prefix}attention_mask"]
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        flat_hidden_states = outputs.last_hidden_state
        sequence_embeds = self.pooling_fn(flat_hidden_states, attention_mask)
        logits = self.head(sequence_embeds)
        reshaped_logits = logits.view(-1, self.num_choices)

        loss = F.cross_entropy(reshaped_logits, batch["labels"])
        return {"loss": loss, "logits": reshaped_logits}


class AutoModuleForMultipleChoice2(AutoModule):
    """Multiple-choice module scoring every choice from a single forward pass.

    Each choice embedding is the mask-weighted average of the token states,
    using `batch["mean_mask"]`; a bias-free linear head maps every choice
    embedding to one score.
    """

    def __init__(self, num_choices: int = 4, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # head as per XLMRobertaForMultipleChoice: one score per choice
        if not self.is_nllb:
            hidden_size = self.model.config.hidden_size
        elif hasattr(self.model, "llama"):
            hidden_size = self.model.llama.config.hidden_size
        else:
            hidden_size = self.model.nllb.config.hidden_size
        self.head = nn.Linear(hidden_size, 1, bias=False)
        self.num_choices = num_choices

    def forward(self, batch, *args, **kwargs):
        """Return a dict with the cross-entropy `loss`, `logits` reshaped to
        ``(-1, num_choices)`` and the per-choice `choice_embeds`."""
        # NOTE(review): reads the unprefixed batch keys, i.e. `batch_prefix`
        # is not applied here, unlike in the pooled task modules.
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        # per the original author's notes: token states are (N, L, D) and the
        # choice mask is (N, L, C)
        token_states = outputs.last_hidden_state
        choice_mask = batch["mean_mask"]

        # mask-weighted sum over the sequence axis -> (N, C, D)
        summed = torch.einsum("nlc,nld->ncd", choice_mask, token_states)
        # total mask weight per choice -> (N, C)
        weight_per_choice = choice_mask.sum(1)
        # normalise the sum into an average; the trailing axis broadcasts over D
        choice_embeds = summed / weight_per_choice.unsqueeze(-1)

        logits = self.head(choice_embeds).view(-1, self.num_choices)
        return {
            "loss": F.cross_entropy(logits, batch["labels"]),
            "logits": logits,
            "choice_embeds": choice_embeds,
        }


class AutoModuleForMultipleChoiceDistillation(AutoModuleForMultipleChoice2):
    """Multiple-choice module whose head is restored from a checkpoint and
    which is trained to match pre-computed choice embeddings.

    Args:
        ckpt: Path to a checkpoint whose ``"state_dict"`` contains
            ``"head.weight"``; that tensor is loaded into `self.head`.
    """

    def __init__(self, ckpt: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prefer cuda:0 as before, but do not require a GPU just to read the
        # head weights; load_state_dict copies into the existing parameter.
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
        state_dict = torch.load(ckpt, map_location=map_location)["state_dict"]
        ckpt_ = {"weight": state_dict["head.weight"]}
        assert isinstance(self.head, nn.Module)
        self.head.load_state_dict(ckpt_, strict=True)

    def training_step(self, batch: dict, batch_idx: int) -> dict:
        """MSE between the model's choice embeddings and `batch["choice_embeds"]`."""
        out = self(batch)
        loss = F.mse_loss(out["choice_embeds"], batch["choice_embeds"])
        self.log("train/mse", loss)
        return {"loss": loss}


class AutoModuleForQuestionAnswering(AutoModule):
    """Extractive QA module: a per-token linear layer yields start/end logits.

    Adapted from XLMRobertaForQuestionAnswering.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Resolve the hidden size the same way the other task modules do.
        if self.is_nllb:
            if hasattr(self.model, "llama"):
                hidden_size = self.model.llama.config.hidden_size
            else:
                hidden_size = self.model.nllb.config.hidden_size
        else:
            hidden_size = self.model.config.hidden_size
        # two outputs per token: start logit and end logit
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, batch, *args, **kwargs):
        """
        Adapted from XLMRobertaForQuestionAnswering.

        Reads `input_ids` and `attention_mask` from `batch`, and optionally
        `start_positions` / `end_positions`. Returns a
        `QuestionAnsweringModelOutput`; `loss` is None unless both position
        tensors are present.
        """
        outputs = self.model(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        hidden_states = outputs.last_hidden_state

        # Span prediction needs one (start, end) logit pair per token, so the
        # projection is applied to the token-level states, not a pooled vector.
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        # .get(): positions are absent at prediction time
        start_positions = batch.get("start_positions")
        end_positions = batch.get("end_positions")
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Out-of-range positions were clamped onto `ignored_index`, which
            # must be excluded from the loss; it is not a valid class index.
            start_loss = F.cross_entropy(
                start_logits, start_positions, ignore_index=ignored_index
            )
            end_loss = F.cross_entropy(
                end_logits, end_positions, ignore_index=ignored_index
            )
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            # not every wrapped model output is guaranteed to carry these
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
