from typing import Any, Dict, Tuple, List, Union
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import GPT2LMHeadModel, AutoTokenizer
from datasets import load_metric

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class GPTGenerationFlyModel(FlyModel):
    """FlyModel wrapper around a pretrained ``GPT2LMHeadModel``.

    The model and its tokenizer are both loaded from
    ``config.task.pretrained_model``. Loss is tracked with a
    ``MovingAverage`` during training and an ``Average`` during evaluation;
    evaluation additionally reports perplexity and its negation as a score.
    """

    def __init__(self, config):
        """Load the GPT-2 model and tokenizer, then set up the metrics.

        Args:
            config: Configuration object; ``config.task.pretrained_model`` is
                passed to both ``from_pretrained`` calls.
        """
        super().__init__(config)
        model_name = config.task.pretrained_model
        # All three *_pdrop overrides are passed as 0.0 when loading the model.
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            model_name,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Metric containers are created last, after the sub-modules exist.
        self.configure_metrics()

    def configure_metrics(self):
        """Create the loss trackers used for training and evaluation."""
        self.training_metrics = {"loss": MovingAverage()}
        self.evaluation_metrics = {"loss": Average()}

    def _run_language_model(self, inputs):
        """Call GPT-2 on ``inputs`` and return its output object.

        ``inputs`` must provide the keys ``"input_ids"``,
        ``"attention_mask"`` and ``"labels"``.
        """
        return self.gpt2(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"],
            return_dict=True,
        )

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Run a training forward pass and record the loss.

        Args:
            batch: Indexable container whose element ``0`` holds the model
                inputs (``"input_ids"``, ``"attention_mask"``, ``"labels"``).
                NOTE(review): the annotation says ``Dict`` but the code reads
                ``batch[0]`` -- confirm what the trainer actually passes.

        Returns:
            The output object returned by the GPT-2 call; its ``loss`` is
            also pushed into the training loss metric.
        """
        inputs = batch[0]
        output = self._run_language_model(inputs)
        loss_tracker = self.training_metrics["loss"]
        loss_tracker(output.loss.item())
        return output

    def predict_step(self, batch):
        """Run one evaluation step and accumulate the loss.

        Unlike ``forward``, ``batch`` is used directly (no ``[0]`` indexing).

        Args:
            batch: Mapping with ``"input_ids"``, ``"attention_mask"`` and
                ``"labels"``.

        Returns:
            Always ``None``; the only effect is updating the evaluation loss
            metric.
        """
        output = self._run_language_model(batch)
        loss_tracker = self.evaluation_metrics["loss"]
        loss_tracker(output.loss.item())

        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training loss formatted to four decimals."""
        current_loss = self.training_metrics["loss"].get_metric()
        return {"loss": f"{current_loss:.4f}"}

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return formatted evaluation loss, perplexity and score.

        Perplexity is ``exp(loss)`` and the score is the negated perplexity
        (presumably so that a larger score is better -- confirm against the
        trainer's checkpoint-selection logic).
        """
        mean_loss = self.evaluation_metrics["loss"].get_metric()
        perplexity = np.exp(mean_loss)
        neg_perplexity = -perplexity

        return {
            "loss": f"{mean_loss:8.4f}",
            "ppl": f"{perplexity:8.4f}",
            "score": f"{neg_perplexity:8.4f}",
        }

    def validation_loop(self, dataloader):
        """Run the parent validation loop, except before any training step.

        Nothing happens while ``self.trainer.global_step_count`` is not
        greater than zero.
        """
        if not self.trainer.global_step_count > 0:
            return
        super().validation_loop(dataloader)

    def reset_evaluation_metrics(self):
        """Reset parent evaluation metrics and clear the ``gts``/``preds`` lists.

        The two lists are not read anywhere in this class as shown.
        """
        super().reset_evaluation_metrics()
        self.gts, self.preds = [], []
