from typing import Any, Dict
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from transformers import PretrainedConfig, BartModel
from collections import OrderedDict

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed

import memformers
from memformers.models.memformerA4_base.memformer_flymodel import MemformerModel, get_model_config, PretrainedConfig


class ClassificationHead(nn.Module):
    """Two-layer MLP mapping a pooled sentence vector to class logits.

    Computes ``out_proj(dropout(tanh(dense(dropout(x)))))``; one shared
    dropout module is used at both dropout sites.

    Args:
        input_dim: size of the last dimension of the incoming representation.
        inner_dim: width of the hidden tanh layer.
        num_classes: number of output logits.
        pooler_dropout: dropout probability used at both dropout sites.
    """

    def __init__(
        self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float,
    ):
        super(ClassificationHead, self).__init__()
        # Submodules are registered in the order dense -> dropout -> out_proj so
        # parameter initialisation order and state-dict key names stay unchanged.
        self.dense = nn.Linear(in_features=input_dim, out_features=inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(in_features=inner_dim, out_features=num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return unnormalised logits whose last dimension is ``num_classes``."""
        projected = self.dense(self.dropout(hidden_states))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)


def process_weights(state_dict, prefix_len: int = 35):
    """Strip a fixed-length prefix from every key of a checkpoint state dict.

    The pretraining checkpoint stores parameters under a longer module path
    than ``MemformerModel`` expects; dropping the first ``prefix_len``
    characters of each key maps them onto the bare model's names.
    NOTE(review): the default of 35 is presumably the length of the wrapper
    module path used during pretraining — confirm against the checkpoint.

    Args:
        state_dict: mapping from parameter name to tensor (any mapping works).
        prefix_len: number of leading characters to drop from each key.

    Returns:
        An ``OrderedDict`` with the same values and insertion order, keyed by
        the shortened names. Keys no longer than ``prefix_len`` become ``""``
        and will overwrite one another.
    """
    return OrderedDict((key[prefix_len:], value) for key, value in state_dict.items())


class MemformerClassificationFlyModel(FlyModel):
    """Sentence classifier on top of a pretrained Memformer encoder-decoder.

    The input ids are fed to both the encoder and the decoder. The sentence
    representation is the final hidden state at the last ``eos_token_id``
    position of each example, which a ``ClassificationHead`` maps to
    ``NUM_CLASSES`` logits trained with cross-entropy.
    """

    # Number of target classes produced by the classification head.
    NUM_CLASSES = 3
    # Pretrained Memformer checkpoint, resolved relative to the working directory.
    PRETRAINED_CHECKPOINT = (
        "../../../../Pretraining/exp_results/memformerA4_base/denosing_pretrain_copy_3/"
        "Trainer1_Stage1/Checkpoints/iter_645718_model_state.pth"
    )

    def __init__(self, config):
        super().__init__(config)

        # Use a separate name so the trainer ``config`` handed to FlyModel is
        # not shadowed by the backbone's model config.
        model_config = PretrainedConfig.from_dict(get_model_config())
        model_config.activation_dropout = 0.0
        self.model = MemformerModel(model_config)

        # To start from BART weights instead of the Memformer checkpoint:
        #   bart = BartModel.from_pretrained("facebook/bart-base")
        #   print(self.model.load_state_dict(bart.state_dict(), strict=False))
        # map_location="cpu" lets the checkpoint load on hosts without the device
        # it was saved from; load_state_dict copies values into the parameters.
        state_dict = torch.load(self.PRETRAINED_CHECKPOINT, map_location="cpu")
        # strict=False reports missing/unexpected keys instead of raising.
        print(self.model.load_state_dict(process_weights(state_dict), strict=False))

        d_model = self.model.config.d_model
        self.classification_head = ClassificationHead(
            d_model, d_model, num_classes=self.NUM_CLASSES, pooler_dropout=0.1,
        )
        self.loss_fct = nn.CrossEntropyLoss()

        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create fresh metric trackers for training and evaluation."""
        self.training_metrics = {"loss": MovingAverage()}
        self.evaluation_metrics = {"loss": Average(), "acc": CategoricalAccuracy()}

    def _classify(self, batch: Dict[str, Any]):
        """Run the backbone and classification head on ``batch``.

        Args:
            batch: must contain ``"input_ids"`` and ``"attention_mask"``.

        Returns:
            ``(outputs, logits)``: the backbone's output object and the
            classification logits for each example.

        Raises:
            ValueError: if examples contain different numbers of eos tokens.
        """
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            decoder_input_ids=batch["input_ids"],
            return_dict=True,
        )
        hidden_states = outputs[0]

        eos_mask = batch["input_ids"].eq(self.model.config.eos_token_id)
        # The reshape below regroups the flat list of eos states per example.
        # It is only valid when every row has the same number of eos tokens;
        # otherwise (when the total is still divisible by the batch size) rows
        # would silently pick up eos states belonging to their neighbours.
        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")

        batch_size, hidden_size = hidden_states.size(0), hidden_states.size(-1)
        # Keep only the last eos position of each example.
        sentence_representation = hidden_states[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]

        logits = self.classification_head(sentence_representation)
        return outputs, logits

    def _loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between ``logits`` and flattened integer ``labels``."""
        return self.loss_fct(logits.view(-1, self.NUM_CLASSES), labels.view(-1))

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Training step: compute the loss, record it, and attach it to the outputs.

        Args:
            batch: must contain ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.

        Returns:
            The backbone output object with a ``loss`` attribute set.
        """
        outputs, logits = self._classify(batch)
        loss = self._loss(logits, batch["labels"])
        self.training_metrics["loss"](loss.item())
        outputs.loss = loss
        return outputs

    def predict_step(self, batch):
        """Evaluation step: update the evaluation loss and accuracy metrics.

        Args:
            batch: must contain ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.

        Returns:
            ``None``; results are accumulated in ``self.evaluation_metrics``.
        """
        _, logits = self._classify(batch)
        loss = self._loss(logits, batch["labels"])
        self.evaluation_metrics["loss"](loss.item())
        self.evaluation_metrics["acc"](predictions=logits.detach(), gold_labels=batch["labels"])
        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the moving-average training loss formatted to four decimals."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return evaluation loss/accuracy as ``(formatted, raw)`` pairs.

        ``"score"`` mirrors accuracy; presumably the trainer uses it for
        checkpoint selection — confirm against the torchfly trainer.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        acc = self.evaluation_metrics["acc"].get_metric()
        metrics = {
            "loss": (f"{loss:.4f}", loss),
            "acc": (f"{acc:.4f}", acc),
            "score": (f"{acc:.4f}", acc),
        }
        return metrics
