import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import BartModel

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class BartClassificationHead(nn.Module):
    """Projection head mapping hidden vectors to per-class scores.

    The pipeline is: dropout -> linear -> tanh -> dropout -> linear.
    The same dropout module is applied both before and after the inner
    projection.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        inner_dim: Width of the intermediate (tanh-activated) projection.
        num_classes: Number of output scores per input vector.
        pooler_dropout: Dropout probability used at both dropout sites.
    """

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        # Attribute names and creation order are kept stable so that
        # state-dict keys and parameter-initialisation order do not change.
        self.dense = nn.Linear(in_features=input_dim, out_features=inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(in_features=inner_dim, out_features=num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return class scores for ``hidden_states``.

        Only the last dimension is transformed (``input_dim`` ->
        ``num_classes``); leading dimensions pass through unchanged.
        """
        projected = self.dense(self.dropout(hidden_states))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

class BartEncoderClassificationFlyModel(FlyModel):
    """Sequence classifier built on the encoder of a pretrained BART model.

    Only ``self.bart.encoder`` is run. The encoder states at every EOS
    token position are passed through the classification head and the
    resulting per-position logits are averaged to give one logit vector
    per example.

    Batches are dicts providing ``"input_ids"``, ``"attention_mask"`` and
    ``"labels"``.
    """

    # Number of target classes; single source of truth for the head size.
    NUM_CLASSES = 3

    def __init__(self, config):
        """Load pretrained BART and set up the head, loss and metrics.

        Args:
            config: Experiment configuration. ``config.task.pretrained_model``
                is forwarded to ``BartModel.from_pretrained``.
        """
        super().__init__(config)

        self.bart = BartModel.from_pretrained(
            config.task.pretrained_model,
            dropout=0.1,
            classifier_dropout=0.1,
            activation_dropout=0.1,
        )

        self.classification_head = BartClassificationHead(
            self.bart.config.d_model,
            self.bart.config.d_model,
            num_classes=self.NUM_CLASSES,
            pooler_dropout=0.1,
        )
        self.loss_fct = nn.CrossEntropyLoss()

        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create the metric trackers used for training and evaluation."""
        self.training_metrics = {"loss": MovingAverage()}
        self.evaluation_metrics = {
            "loss": Average(),
            "acc": CategoricalAccuracy(),
        }

    def _encode(self, batch: Dict[str, Any]) -> Any:
        """Run the BART encoder on the batch and return its output object."""
        return self.bart.encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            return_dict=True,
        )

    def _classify(
        self, input_ids: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Pool encoder states at EOS positions and return per-example logits.

        Args:
            input_ids: Token ids; its shape must match the leading
                dimensions of ``hidden_states`` so it can serve as a mask.
            hidden_states: Encoder output states (first element of the
                encoder output).

        Returns:
            Logits averaged over the EOS positions of each example.

        Raises:
            ValueError: If the examples do not all contain the same,
                non-zero number of EOS tokens. Without this check the
                reshape below could silently mix representations from
                different examples whenever the total EOS count happens
                to be divisible by the batch size.
        """
        eos_mask = input_ids.eq(self.bart.config.eos_token_id)
        distinct_counts = torch.unique(eos_mask.sum(dim=1))
        if distinct_counts.numel() != 1 or distinct_counts.item() == 0:
            raise ValueError(
                "All examples must contain the same, non-zero number of "
                "<eos> tokens."
            )
        eos_per_example = int(distinct_counts.item())

        # Boolean indexing flattens the selected positions; regroup them
        # per example as (batch, eos_per_example, hidden).
        sentence_representation = hidden_states[eos_mask].view(
            hidden_states.size(0), eos_per_example, hidden_states.size(-1)
        )
        logits = self.classification_head(sentence_representation)
        return logits.mean(1)

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the cross-entropy loss between ``logits`` and ``labels``."""
        # Derive the class dimension from the logits instead of repeating
        # a literal class count here.
        return self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the training loss for ``batch``.

        Updates the training ``"loss"`` metric as a side effect.

        Returns:
            The encoder output object with the loss attached as ``.loss``.

        Raises:
            ValueError: See ``_classify``.
        """
        output = self._encode(batch)
        logits = self._classify(batch["input_ids"], output[0])
        loss = self._compute_loss(logits, batch["labels"])

        self.training_metrics["loss"](loss.item())
        output.loss = loss
        return output

    def predict_step(self, batch):
        """Evaluate one batch, updating the evaluation loss and accuracy.

        Returns:
            ``None``; results are accumulated in ``self.evaluation_metrics``.

        Raises:
            ValueError: See ``_classify``.
        """
        output = self._encode(batch)
        logits = self._classify(batch["input_ids"], output[0])
        loss = self._compute_loss(logits, batch["labels"])

        self.evaluation_metrics["loss"](loss.item())
        self.evaluation_metrics["acc"](
            predictions=logits.detach(), gold_labels=batch["labels"]
        )
        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training loss formatted to four decimals."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, Any]:
        """Return evaluation metrics as ``(formatted, raw)`` pairs.

        ``"score"`` duplicates the accuracy value.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        acc = self.evaluation_metrics["acc"].get_metric()
        metrics = {
            "loss": (f"{loss:.4f}", loss),
            "acc": (f"{acc:.4f}", acc),
            "score": (f"{acc:.4f}", acc),
        }
        return metrics
