import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import BartModel, BartForSequenceClassification

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class BartClassificationFlyModel(FlyModel):
    """Sequence classifier built on Hugging Face ``BartForSequenceClassification``.

    Wraps the pretrained BART classifier in a torchfly ``FlyModel``. It tracks a
    moving-average loss during training, and average loss plus categorical
    accuracy during evaluation.
    """

    def __init__(self, config):
        """Load the pretrained BART classifier and set up metrics.

        Args:
            config: torchfly configuration.
                ``config.task.pretrained_model`` names the checkpoint to load.
                ``config.task.num_labels`` is optional; it defaults to 3, the
                value this class previously hard-coded.
        """
        super().__init__(config)
        # Fall back to 3 when the key is absent (getattr default) or resolves
        # to None, so existing configs keep their previous behavior.
        num_labels = getattr(config.task, "num_labels", None) or 3
        self.bart = BartForSequenceClassification.from_pretrained(
            config.task.pretrained_model,
            num_labels=num_labels,
            problem_type="single_label_classification",
            dropout=0.0,
            classifier_dropout=0.1,
        )
        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create fresh metric trackers for training and evaluation."""
        self.training_metrics = {"loss": MovingAverage()}
        self.evaluation_metrics = {
            "loss": Average(), "acc": CategoricalAccuracy()}

    def _run_bart(self, batch: Dict[str, Any]):
        """Run the wrapped BART classifier on ``batch``; return its output.

        ``batch`` must provide ``input_ids``, ``attention_mask`` and ``labels``.
        Because ``labels`` is passed in, the returned output carries both
        ``loss`` and ``logits``.
        """
        # The encoder input is also passed verbatim as the decoder input.
        # Supplying decoder_input_ids explicitly bypasses BART's default of
        # deriving them by shifting input_ids right.
        # NOTE(review): confirm this unshifted decoder input is intended.
        return self.bart(
            input_ids=batch["input_ids"],
            decoder_input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            return_dict=True,
        )

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Training forward pass.

        Runs the model, records the batch loss in the training moving
        average, and returns the model output (which includes ``loss``).
        """
        output = self._run_bart(batch)
        self.training_metrics["loss"](output.loss.item())
        return output

    def predict_step(self, batch):
        """Evaluation step: update the loss and accuracy metrics for ``batch``.

        Returns ``None``; results are only accumulated in
        ``self.evaluation_metrics``.
        """
        # Nothing here back-propagates, so skip building the autograd graph.
        # This saves activation memory; it is harmless if the caller already
        # disabled gradients.
        with torch.no_grad():
            output = self._run_bart(batch)
        self.evaluation_metrics["loss"](output.loss.item())
        self.evaluation_metrics["acc"](
            predictions=output.logits, gold_labels=batch["labels"])
        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training loss formatted to 4 decimal places."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return evaluation metrics as ``(formatted_str, raw_value)`` pairs.

        ``score`` mirrors ``acc``; it is presumably the value the trainer
        uses for model selection (unverified).
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        acc = self.evaluation_metrics["acc"].get_metric()
        metrics = {"loss": (f"{loss:.4f}", loss), "acc": (
            f"{acc:.4f}", acc), "score": (f"{acc:.4f}", acc)}
        return metrics
