from typing import Any, Dict, Tuple, List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.training import FlyModel
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper
from transformers import BartModel, BartForConditionalGeneration, PretrainedConfig

from .utils import get_model_config


class BartBaseFlyModel(FlyModel):
    """FlyModel wrapper around the pretrained ``facebook/bart-base`` seq2seq model.

    The wrapped network is a ``BartForConditionalGeneration`` instance stored
    on ``self.model``. The class also registers the metric objects used for
    training/evaluation and formats the training metrics for display.
    """

    def __init__(self, config):
        """Initialise the base class with ``config`` and load the BART weights.

        Args:
            config: Configuration object, passed unchanged to ``FlyModel``.
        """
        super().__init__(config)
        pretrained_name = "facebook/bart-base"
        self.model = BartForConditionalGeneration.from_pretrained(pretrained_name)

    def forward(self, batch):
        """Run the wrapped BART model on one batch and record its loss.

        Args:
            batch: Mapping that provides the entries ``"encoder_input_ids"``,
                ``"encoder_attention_mask"``, ``"decoder_input_ids"`` and
                ``"target"`` (the latter is passed to the model as ``labels``).

        Returns:
            The output object of the wrapped model (requested with
            ``return_dict=True``); its ``loss`` attribute is the value that
            gets recorded in the training loss metric.
        """
        encoder_ids = batch["encoder_input_ids"]
        encoder_mask = batch["encoder_attention_mask"]
        decoder_ids = batch["decoder_input_ids"]
        targets = batch["target"]

        outputs = self.model(
            encoder_ids,
            attention_mask=encoder_mask,
            decoder_input_ids=decoder_ids,
            labels=targets,
            return_dict=True,
        )

        # NOTE(review): the loss is always written to the *training* metric,
        # regardless of the module's train/eval mode -- confirm this matches
        # how the trainer drives evaluation.
        loss_value = outputs.loss.item()
        self.training_metrics["loss"](loss_value)
        return outputs

    def configure_metrics(self):
        """Create the metric trackers for training and evaluation."""
        # NOTE(review): "tok/s" is registered here but never updated inside
        # this class -- presumably the trainer feeds it; verify.
        running_loss = MovingAverage(name="loss")
        throughput = Speed()
        self.training_metrics = {"loss": running_loss, "tok/s": throughput}
        self.evaluation_metrics = {"loss": Average()}

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training metrics as pre-formatted strings.

        Returns:
            A dict with the keys ``"lr"``, ``"loss"`` and ``"ppl"``, where
            ``ppl`` is ``exp(loss)`` of the current training loss metric.
        """
        current_loss = self.training_metrics["loss"].get_metric()
        perplexity = np.exp(current_loss)
        # `get_last_lr` is not defined in this class (presumably inherited
        # from FlyModel); only its first entry is reported.
        learning_rate = self.get_last_lr()[0]

        return {
            "lr": f"{learning_rate:3.2e}",
            "loss": f"{current_loss:8.4f}",
            "ppl": f"{perplexity:8.4f}",
        }