import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import BartModel, BartForConditionalGeneration, GPT2LMHeadModel, AutoTokenizer

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class GPTGenerationFlyModel(FlyModel):
    """GPT-2 language model wrapped as a torchfly ``FlyModel``.

    Training uses the causal-LM loss computed by ``GPT2LMHeadModel`` when
    ``labels`` are passed. Evaluation computes a per-sequence mean token loss
    and aggregates it both overall and per session (``session1`` ..
    ``session{NUM_SESSIONS}``), reporting the corresponding perplexities.
    """

    # Number of session buckets tracked during evaluation; metric keys are
    # "session1_loss" .. f"session{NUM_SESSIONS}_loss".
    NUM_SESSIONS = 5
    # Label value excluded from the loss (positions that are not supervised).
    IGNORE_INDEX = -100

    def __init__(self, config):
        super().__init__(config)
        # All GPT-2 dropout probabilities (residual, embedding, attention) are set to 0.
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.task.pretrained_model, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.task.pretrained_model)
        # reduction="none" keeps per-token losses so evaluation can average per sequence.
        self.eval_loss_fct = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX, reduction="none")

        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create the metric trackers used during training and evaluation."""
        self.training_metrics = {
            "loss": MovingAverage(name="loss"),
        }
        self.evaluation_metrics = {"loss": Average()}
        for session in range(1, self.NUM_SESSIONS + 1):
            self.evaluation_metrics[f"session{session}_loss"] = Average()

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model on a training batch and record its loss.

        NOTE(review): annotated as a dict, but the body indexes ``batch[0]``;
        presumably the trainer passes a sequence whose first element is the
        batch dict — confirm against the torchfly trainer.

        Returns:
            The HuggingFace model output (``return_dict=True``), whose
            ``loss`` field holds the LM loss for ``batch["labels"]``.
        """
        batch = batch[0]
        output = self.gpt2(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            return_dict=True,
        )
        self.training_metrics["loss"](output.loss.item())
        return output

    @torch.no_grad()  # results are only read back as Python floats; no graph is needed
    def predict_step(self, batch):
        """Accumulate per-sequence evaluation losses for one batch.

        Each sequence's loss is the mean cross-entropy over its supervised
        tokens (labels != ``IGNORE_INDEX``), with logits at position t scored
        against the label at position t + 1. The loss is added to the overall
        average and to the average for the sequence's ``session_id``.
        Sequences whose session id is ``"none"``, or that contain no
        supervised tokens, are skipped.

        Returns:
            None; accumulated results are read via ``get_evaluation_metrics``.
        """
        output = self.gpt2(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True)

        session_ids = batch["session_id"]
        batch_size = len(session_ids)
        # Shift so that the logits at position t are compared with the label at t + 1.
        shift_logits = output.logits[..., :-1, :].contiguous()
        shift_labels = batch["labels"][..., 1:].contiguous()

        token_losses = self.eval_loss_fct(
            shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1)
        ).view(batch_size, -1)
        num_tokens = (shift_labels != self.IGNORE_INDEX).sum(-1)
        # Divide by the exact token count; clamp only guards rows with zero
        # supervised tokens, which are skipped below rather than counted as 0 loss.
        seq_losses = token_losses.sum(-1) / num_tokens.clamp(min=1)

        for sess_id, seq_loss, n_tokens in zip(session_ids, seq_losses.tolist(), num_tokens.tolist()):
            if sess_id != "none" and n_tokens > 0:
                self.evaluation_metrics[f"session{sess_id}_loss"](seq_loss)
                self.evaluation_metrics["loss"](seq_loss)
        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the tracked training loss formatted to four decimals."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Format the evaluation loss and perplexities as strings.

        Perplexity is ``exp`` of the averaged per-sequence loss. ``score`` is
        the negated perplexity of the last session (session ``NUM_SESSIONS``),
        so a lower perplexity yields a higher score.

        Returns:
            Keys ``loss``, ``ppl``, ``s1_ppl`` .. ``s{NUM_SESSIONS}_ppl`` and
            ``score``, each formatted as ``8.4f``.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        session_ppls = {
            session: np.exp(self.evaluation_metrics[f"session{session}_loss"].get_metric())
            for session in range(1, self.NUM_SESSIONS + 1)
        }
        score = -session_ppls[self.NUM_SESSIONS]

        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{np.exp(loss):8.4f}",
        }
        metrics.update({f"s{session}_ppl": f"{ppl:8.4f}" for session, ppl in session_ppls.items()})
        metrics["score"] = f"{score:8.4f}"
        return metrics
