import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import LEDForConditionalGeneration

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class LEDGenerationFlyModel(FlyModel):
    """FlyModel wrapper around a pretrained ``LEDForConditionalGeneration``.

    Training records a moving average of the model's own loss. Evaluation
    records the batch-mean of a length-normalised per-example loss, plus a
    separate average for every session id listed in ``_SESSION_IDS``; these
    are reported as perplexities by ``get_evaluation_metrics``.
    """

    # Session ids that own a "session{id}_loss" evaluation accumulator.
    # Single source of truth for configure_metrics / get_evaluation_metrics.
    _SESSION_IDS = (1, 2, 3, 4, 5)

    def __init__(self, config):
        """Load the pretrained model and set up loss function and metrics.

        Args:
            config: Configuration object. ``config.task.pretrained_model`` is
                passed to ``LEDForConditionalGeneration.from_pretrained``; the
                whole object is forwarded to ``FlyModel.__init__``.
        """
        super().__init__(config)
        self.bart = LEDForConditionalGeneration.from_pretrained(
            config.task.pretrained_model, dropout=0.0, attention_dropout=0.1, activation_dropout=0.1
        )

        # Unreduced loss so predict_step can normalise per example itself;
        # targets equal to -100 contribute zero loss.
        self.eval_loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create the training and evaluation metric accumulators."""
        evaluation_metrics = {"loss": Average()}
        for session_id in self._SESSION_IDS:
            evaluation_metrics[f"session{session_id}_loss"] = Average()

        self.training_metrics = {
            "loss": MovingAverage(name="loss"),
        }
        self.evaluation_metrics = evaluation_metrics

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Run a training forward pass and record the loss.

        Args:
            batch: Indexable container whose element ``0`` is a mapping with
                "encoder_input_ids", "decoder_input_ids",
                "encoder_attention_mask" and "target" entries.
                NOTE(review): annotated as ``Dict`` but indexed with ``0`` --
                presumably a sequence of dicts; confirm against the caller.

        Returns:
            The model output (requested with ``return_dict=True``), including
            ``loss`` computed by the model from ``labels``.
        """
        batch = batch[0]
        output = self.bart(
            input_ids=batch["encoder_input_ids"],
            decoder_input_ids=batch["decoder_input_ids"],
            attention_mask=batch["encoder_attention_mask"],
            labels=batch["target"],
            return_dict=True,
        )
        self.training_metrics["loss"](output.loss.item())
        return output

    def predict_step(self, batch):
        """Score one evaluation batch and update the evaluation metrics.

        Args:
            batch: Mapping with "encoder_input_ids", "decoder_input_ids",
                "encoder_attention_mask", "decoder_attention_mask", "target"
                and "session_id" entries. Unlike ``forward``, the mapping is
                used directly (no ``batch[0]`` unwrapping). "session_id" must
                support ``len()`` and iteration; entries equal to ``"none"``
                are left out of the per-session metrics.

        Returns:
            None. Results are recorded in ``self.evaluation_metrics``.

        Raises:
            KeyError: If a session id other than ``"none"`` has no matching
                ``"session{id}_loss"`` metric.
        """
        # Nothing computed here is returned as a tensor or back-propagated
        # (only Python floats are recorded), so skip building an autograd
        # graph even if the caller did not disable gradients.
        with torch.no_grad():
            output = self.bart(
                input_ids=batch["encoder_input_ids"],
                decoder_input_ids=batch["decoder_input_ids"],
                attention_mask=batch["encoder_attention_mask"],
                return_dict=True,
            )

            batch_size = len(batch["session_id"])
            vocab_size = output.logits.shape[-1]
            token_loss = self.eval_loss_fct(
                output.logits.view(-1, vocab_size), batch["target"].view(-1)
            ).view(batch_size, -1)

            # Per-example loss: summed token loss over the decoder mask count.
            # The 1e-5 guards against division by zero for an all-zero mask.
            # NOTE(review): assumes the decoder mask count matches the number
            # of non-ignored targets -- confirm against the data pipeline.
            target_lengths = batch["decoder_attention_mask"].sum(-1)
            example_loss = token_loss.sum(-1) / (target_lengths + 1e-5)

        for sess_id, sess_loss in zip(batch["session_id"], example_loss.tolist()):
            if sess_id != "none":
                self.evaluation_metrics[f"session{sess_id}_loss"](sess_loss)

        self.evaluation_metrics["loss"](example_loss.mean().item())

        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training loss formatted to four decimals."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return formatted evaluation loss, perplexities and model score.

        Returns:
            Mapping with "loss", "ppl" (``exp(loss)``), one "s{id}_ppl" entry
            per session id in ``_SESSION_IDS`` and "score", the negated
            perplexity of the last session id (session 5). All values are
            formatted with ``8.4f``.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{np.exp(loss):8.4f}",
        }

        session_ppl = {}
        for session_id in self._SESSION_IDS:
            session_loss = self.evaluation_metrics[f"session{session_id}_loss"].get_metric()
            session_ppl[session_id] = np.exp(session_loss)
            metrics[f"s{session_id}_ppl"] = f"{session_ppl[session_id]:8.4f}"

        # Negated so that a lower last-session perplexity gives a higher score.
        score = -session_ppl[self._SESSION_IDS[-1]]
        metrics["score"] = f"{score:8.4f}"
        return metrics
