from typing import Any, cast
from transformers import NllbTokenizerFast
from transformers.models.m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from trident.core.module import TridentModule
from peft import get_peft_model
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf.dictconfig import DictConfig
from pathlib import Path
import pandas as pd
from hydra.utils import instantiate
from copy import deepcopy


class TranslationModule(TridentModule):
    """Trident module whose forward pass is ``self.model.generate``.

    Args:
        generate_kwargs: keyword arguments forwarded to ``generate`` on every
            call. They are passed through Hydra's ``instantiate`` once at
            construction time (presumably so nested ``_target_`` configs get
            resolved).
    """

    def __init__(self, generate_kwargs: dict | DictConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generate_kwargs = instantiate(generate_kwargs)

    def forward(self, batch: dict[str, torch.Tensor]):
        """Generate from the tensor entries of ``batch``.

        Non-tensor values are left out of the ``generate`` call. An example
        is the ``*_source`` string lists that ``CollatorForTranslation`` adds.

        Returns:
            ``{"preds": <output of self.model.generate>}``.
        """
        model_inputs = {}
        for name, value in batch.items():
            if isinstance(value, torch.Tensor):
                model_inputs[name] = value
        generated = self.model.generate(**model_inputs, **self.generate_kwargs)
        return {"preds": generated}


def decode(
    outputs: dict[str, Any],
    batch: dict[str, Any],
    tokenizer: PreTrainedTokenizerFast,
    text: list[str],
    decode_kwargs: dict = {"skip_special_tokens": True},
    *args,
    **kwargs,
):
    # the translations: list[str] are
    # K: translations per text (i.e., N // len(text))
    # [text1_1, ..., text_1_K, ..., text_k_1, ..., text_k_N] aligned
    translations: list[str] = tokenizer.batch_decode(outputs["preds"], **decode_kwargs)
    N = len(translations)
    assert N % len(text) == 0, "Number of text doesn't align with translations"
    per_text_N = N // len(text)
    for i, k in enumerate(text):
        outputs[k] = translations[i * per_text_N : (i + 1) * per_text_N]
        outputs[f"{k}_source"] = batch[f"{k}_source"]

    return outputs


def store_translations(
    outputs: dict[str, list[str]],
    text: str | list[str],
    others: list[str],
    filename: str,
    dir_: str,
    *args,
    **kwargs,
):
    """Write translations, their sources and extra columns to a parquet file.

    Args:
        outputs: mapping that must hold, for every column ``c`` in ``text``,
            both ``c`` and ``f"{c}_source"``, plus every key in ``others``.
        text: text column name(s). A single string is treated as one column.
        others: additional keys of ``outputs`` to store unchanged.
        filename: name of the parquet file created inside ``dir_``.
        dir_: target directory. It is created (with parents) if missing.
    """
    text_columns = [text] if isinstance(text, str) else text
    target_dir = Path(dir_)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Column order: translations, then their sources, then the extra columns.
    table = {name: outputs[name] for name in text_columns}
    table.update({f"{name}_source": outputs[f"{name}_source"] for name in text_columns})
    table.update({name: outputs[name] for name in others})

    pd.DataFrame.from_dict(table).to_parquet(str(target_dir / filename))


class CollatorForTranslation:
    """Collate raw rows into one tokenized batch covering several text columns.

    All text columns are tokenized together in a single call. The texts are
    concatenated column by column, so every row of the first column comes
    first, then every row of the second, and so on. The raw strings of each
    column are kept under ``f"{column}_source"``.

    Args:
        tokenizer: callable tokenizer whose result exposes ``.data``.
        tokenize_kwargs: keyword arguments forwarded to ``tokenizer``.
        columns: ``{"text": [...], "others": [...]}``. ``"others"`` is
            optional; those columns are copied through untokenized.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        tokenize_kwargs: dict[str, Any],
        columns: dict[str, list[str]],
    ) -> None:
        self.tokenizer = tokenizer
        self.tokenize_kwargs = tokenize_kwargs
        self.columns = columns

    def __call__(self, inputs: list[dict[str, str]], *args: Any, **kwds: Any) -> Any:
        batch = {}

        flat_texts = []
        for column in self.columns["text"]:
            column_values = [row[column] for row in inputs]
            batch[f"{column}_source"] = column_values
            flat_texts += column_values

        # Tokenizer outputs are written after the *_source entries.
        batch.update(self.tokenizer(flat_texts, **self.tokenize_kwargs).data)

        extra_columns = self.columns.get("others", None)
        if extra_columns is not None:
            batch.update({column: [row[column] for row in inputs] for column in extra_columns})
        return batch


def get_lang_to_id_code(tokenizer: "NllbTokenizerFast", lang_code: str) -> int:
    """Return the token id of an NLLB language code such as ``"eng_Latn"``.

    The tokenizer's ``lang_code_to_id`` table is used when it exists and
    contains the code. Otherwise the lookup falls back to
    ``convert_tokens_to_ids``. The fallback covers language codes added later
    as special tokens, which that table does not list. It also covers
    tokenizer versions without the table.

    Raises:
        KeyError: if the code is unknown to the tokenizer.
    """
    mapping = getattr(tokenizer, "lang_code_to_id", None)
    if mapping is not None and lang_code in mapping:
        return mapping[lang_code]
    token_id = tokenizer.convert_tokens_to_ids(lang_code)
    # convert_tokens_to_ids maps unknown tokens to the unk id instead of
    # failing, so treat that as "not found" to keep the KeyError contract.
    if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
        raise KeyError(lang_code)
    return token_id


def add_source(texts: list[str]):
    """Interleave each column name with its ``_source`` counterpart.

    >>> add_source(["a", "b"])
    ['a', 'a_source', 'b', 'b_source']
    """
    return [name for column in texts for name in (column, f"{column}_source")]


def get_translation_training_tokenizer(
    src_lang: str = "eng_Latn",
    *args,
    pretrained_model_name_or_path: str = "facebook/nllb-200-distilled-600M",
    new_lang_codes: tuple[str, ...] = (
        "bzd_Latn",
        "cni_Latn",
        "hch_Latn",
        "nah_Latn",
        "oto_Latn",
        "shp_Latn",
        "tar_Latn",
    ),
    **kwargs,
):
    """Load the NLLB tokenizer extended with extra language-code tokens.

    Args:
        src_lang: language code assigned to ``tokenizer.src_lang``.
        pretrained_model_name_or_path: tokenizer checkpoint to load.
        new_lang_codes: codes registered as additional special tokens. They
            are appended to the vocabulary in this order. Callers that assume
            a fixed number of new rows at the end of the embedding matrix
            (e.g. 7) depend on this length.

    Returns:
        The tokenizer with the new codes added and ``src_lang`` set.
    """
    from transformers import AutoTokenizer

    # Only the tokenizer is needed. Loading the model here (as before) would
    # download and instantiate hundreds of millions of unused parameters.
    nllb_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    # NOTE(review): depending on the transformers version, this call may
    # replace rather than extend the tokenizer's `additional_special_tokens`
    # list. For NLLB that list holds the built-in language codes. Confirm
    # this is intended.
    nllb_tokenizer.add_special_tokens({"additional_special_tokens": list(new_lang_codes)})
    nllb_tokenizer.src_lang = src_lang
    return nllb_tokenizer


# nllb_tokenizer.src_lang = "hch_Latn"
# nllb_tokenizer.decode(nllb_tokenizer("This is a test")["input_ids"])
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
# encoder_ie = model.model.encoder.get_input_embeddings()
# decoder_ie = model.model.decoder.get_input_embeddings()
# model.model.encoder.embed_tokens.weight is model.model.decoder.embed_tokens.weight
# w = model.model.encoder.embed_tokens.weight.clone()
# w2 = model.model.encoder.embed_tokens.weight.clone()
# w2[-7:] = w[nllb_tokenizer.lang_code_to_id["spa_Latn"]]
# model.model.decoder.embed_tokens.weight.shape


class EmbedForward(nn.Module):
    """Embedding lookup whose trailing rows come from a separate trainable tensor.

    On every call, the table used for the lookup is built from two parts.
    All rows of ``embeddings.weight`` except the last
    ``len(new_lang_tokens)`` come from the base table. The tail rows come
    from ``new_lang_tokens``. The base weight is frozen, so gradients reach
    only ``new_lang_tokens``.

    NOTE(review): the lookup calls ``F.embedding`` directly. Any extra logic
    in the wrapped module's own ``forward`` (e.g. output scaling in a
    subclass) is therefore bypassed. Confirm this before wrapping a
    non-plain ``nn.Embedding``.

    Args:
        embeddings: base embedding module. Its weight is frozen in place.
        new_lang_tokens: replacement rows for the tail of the table. They are
            stacked under the base rows, so they are presumably shaped
            ``(num_new_tokens, embedding_dim)``.
    """

    def __init__(
        self,
        embeddings: nn.Embedding,
        new_lang_tokens: nn.Parameter,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.embeddings = embeddings
        self.new_lang_tokens = new_lang_tokens
        self.embeddings.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Implemented as `forward` (not `__call__`) so nn.Module hooks still run.
        base_weight = self.embeddings.weight
        # The cut point is computed explicitly. `weight[:-0]` would be empty
        # and drop the whole base table when there are no new tokens.
        n_kept = base_weight.shape[0] - len(self.new_lang_tokens)
        token_embeds = torch.vstack([base_weight[:n_kept], self.new_lang_tokens])
        return F.embedding(
            input_ids,
            weight=token_embeds,
            padding_idx=self.embeddings.padding_idx,
            max_norm=self.embeddings.max_norm,
            norm_type=self.embeddings.norm_type,
            scale_grad_by_freq=self.embeddings.scale_grad_by_freq,
        )


class NLLBTranslationModule(TridentModule):
    """Trident module adapting an NLLB (M2M100) model to new language codes.

    On construction it:

    1. builds the extended tokenizer with ``get_translation_training_tokenizer``
       and resizes the model's token embeddings to ``len(tokenizer)``;
    2. seeds the last ``NUM_NEW_LANG_TOKENS`` embedding rows with the
       embedding of ``INIT_LANG_CODE``;
    3. gives the encoder its own copy of the shared embedding (taken after
       seeding), so encoder and shared embeddings can diverge;
    4. freezes every decoder parameter and marks the encoder embedding as
       trainable. The other encoder parameters are not frozen here.

    Args:
        peft_config: currently unused. It is kept for config compatibility
            (a PEFT-wrapped encoder variant is disabled).
    """

    # Number of language-code tokens appended by
    # `get_translation_training_tokenizer` with its defaults. They occupy the
    # last rows of the resized embedding matrix.
    NUM_NEW_LANG_TOKENS = 7
    # Existing language code whose embedding seeds the new language-code rows.
    INIT_LANG_CODE = "spa_Latn"

    def __init__(self, peft_config: DictConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = cast(M2M100ForConditionalGeneration, self.model)

        self.tokenizer = get_translation_training_tokenizer()
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Assumes the language code encodes to a single special token.
        init_lang_id = self.tokenizer.encode(self.INIT_LANG_CODE, add_special_tokens=False)[0]
        shared_weight = self.model.get_input_embeddings().weight.data
        shared_weight[-self.NUM_NEW_LANG_TOKENS :] = shared_weight[init_lang_id]

        # Untie: the encoder gets an independent copy of the (seeded) shared table.
        self.model.model.encoder.embed_tokens = deepcopy(self.model.model.shared)

        for p in self.model.model.decoder.parameters():
            p.requires_grad = False
        for p in self.model.model.encoder.embed_tokens.parameters():
            p.requires_grad = True

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Replace the checkpoint's ``state_dict`` with ``self.model.state_dict()``.

        The keys are relative to the wrapped HF model.
        NOTE(review): they therefore lack the ``model.`` prefix this module's
        own ``state_dict`` would have. Confirm that resuming via Lightning
        handles this.
        """
        # No copy of the model is needed. `state_dict()` already returns
        # detached tensors, and nothing is modified before saving.
        checkpoint["state_dict"] = self.model.state_dict()
        return super().on_save_checkpoint(checkpoint)
