from typing import List

import torch
from torch.nn import CrossEntropyLoss
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, Text2TextGenerationPipeline
from transformers.file_utils import PaddingStrategy
from transformers.modeling_outputs import Seq2SeqLMOutput

from translation_models import TranslationModel, ScoringModel


class MbartTranslationModel(TranslationModel):
    """Translation model backed by an mBART-50 checkpoint wrapped in a text2text pipeline."""

    def __init__(self, model_name_or_path: str, src_lang: str, tgt_lang: str, *args, **kwargs):
        """
        Load the model and tokenizer and build the generation pipeline.

        Args:
            model_name_or_path: Identifier handed to ``from_pretrained`` for both
                the model and the tokenizer.
            src_lang: Language code assigned to the tokenizer's ``src_lang``.
            tgt_lang: Language code assigned to the tokenizer's ``tgt_lang``.
            *args: Extra positional arguments forwarded to the pipeline constructor.
            **kwargs: Extra keyword arguments forwarded to the pipeline constructor.
        """
        self.model_name_or_path = model_name_or_path
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        mbart = MBartForConditionalGeneration.from_pretrained(model_name_or_path)
        mbart_tokenizer = MBart50TokenizerFast.from_pretrained(model_name_or_path)
        mbart_tokenizer.src_lang = src_lang
        mbart_tokenizer.tgt_lang = tgt_lang

        self.pipeline = Text2TextGenerationPipeline(*args, model=mbart, tokenizer=mbart_tokenizer, **kwargs)

    @property
    def is_to_many_model(self):
        """Whether the checkpoint name contains the ``-to-many-`` marker."""
        return "-to-many-" in self.model_name_or_path

    def translate(self, sentences: List[str], beam: int = 5, **kwargs) -> List[str]:
        """
        Translate ``sentences`` with beam search.

        Args:
            sentences: Source sentences to translate.
            beam: Value passed to the pipeline as ``num_beams``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``generated_text`` of every pipeline result, in pipeline order.
        """
        # Only "-to-many-" checkpoints get the target language code passed as
        # forced_bos_token_id; all other checkpoints receive None.
        if self.is_to_many_model:
            forced_bos_token_id = self.pipeline.tokenizer.lang_code_to_id[self.tgt_lang]
        else:
            forced_bos_token_id = None

        outputs = self.pipeline(
            sentences,
            num_beams=beam,
            forced_bos_token_id=forced_bos_token_id,
        )

        translations = []
        for output in outputs:
            translations.append(output["generated_text"])
        return translations

    def __str__(self):
        """Return the checkpoint name or path the model was loaded from."""
        return self.model_name_or_path


class MbartScoringModel(ScoringModel, MbartTranslationModel):
    """Scores hypothesis translations with the mBART model held by ``self.pipeline``."""

    @torch.no_grad()
    def score(self, source_sentences: List[str], hypothesis_sentences: List[str], batch_size=2) -> List[float]:
        """
        Return one score per (source, hypothesis) pair.

        A score is the negated mean cross-entropy of the hypothesis tokens given
        the source sentence, so values closer to zero are better. Padding
        positions are excluded from the mean so that a pair's score does not
        depend on which other sentences share its batch.

        Args:
            source_sentences: Source-language sentences.
            hypothesis_sentences: Target-language hypotheses, aligned with
                ``source_sentences`` by index.
            batch_size: Number of sentence pairs per forward pass.

        Returns:
            A list of scores with the same length and order as the inputs.

        Raises:
            ValueError: If the two lists differ in length or ``batch_size`` is
                not positive.
        """
        if len(source_sentences) != len(hypothesis_sentences):
            raise ValueError(
                f"Got {len(source_sentences)} source sentences "
                f"but {len(hypothesis_sentences)} hypothesis sentences"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        tokenizer = self.pipeline.tokenizer
        model = self.pipeline.model
        vocab_size = model.config.vocab_size

        # Batches of several sentences must be padded to a common length to form
        # a tensor; a single sentence needs no padding at all.
        padding_strategy = PaddingStrategy.LONGEST if batch_size > 1 else PaddingStrategy.DO_NOT_PAD

        # Labels are padded with the tokenizer's pad id, which the default
        # ignore_index (-100) would treat as a real target token.
        pad_token_id = tokenizer.pad_token_id
        loss_fn = CrossEntropyLoss(ignore_index=pad_token_id if pad_token_id is not None else -100)

        scores = []
        for start in range(0, len(source_sentences), batch_size):
            source_batch = source_sentences[start:start + batch_size]
            hypothesis_batch = hypothesis_sentences[start:start + batch_size]

            inputs = tokenizer._batch_encode_plus(source_batch, return_tensors="pt", padding_strategy=padding_strategy)
            with tokenizer.as_target_tokenizer():
                labels = tokenizer._batch_encode_plus(hypothesis_batch, return_tensors="pt", padding_strategy=padding_strategy)
            inputs["labels"] = labels["input_ids"]
            inputs = self.pipeline.ensure_tensor_on_device(**inputs)
            output: Seq2SeqLMOutput = model(**inputs)

            # The model's own loss is averaged over the whole batch, so the
            # per-sentence loss is recomputed from the logits.
            for sentence_logits, sentence_labels in zip(output.logits, inputs["labels"]):
                loss = loss_fn(sentence_logits.view(-1, vocab_size), sentence_labels.view(-1))
                scores.append(-loss.item())

            # Release the batch's logits before the next forward pass.
            del output
        return scores
