import numpy as np
from typing import Any, Dict, Tuple, List, Union
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict
from transformers import BartModel, BartForConditionalGeneration, AutoTokenizer
from transformers.models.bart.modeling_bart import BaseModelOutput
from datasets import load_metric

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed


class BartGenerationFlyModel(FlyModel):
    """BART conditional-generation model wrapped as a torchfly ``FlyModel``.

    Training reports a moving-average loss. Evaluation accumulates the loss
    and, for every example not flagged ``if_empty``, stores the decoded
    reference and the beam-search generation so that a SQuAD-style F1 can be
    computed in :meth:`get_evaluation_metrics`.
    """

    def __init__(self, config):
        """Load the pretrained BART weights, tokenizer and SQuAD metric.

        Args:
            config: Configuration object; ``config.task.pretrained_model`` is
                passed to ``from_pretrained`` for both model and tokenizer.
        """
        super().__init__(config)
        # All dropout variants are disabled when loading the checkpoint.
        self.bart = BartForConditionalGeneration.from_pretrained(
            config.task.pretrained_model, dropout=0.0, attention_dropout=0.0, activation_dropout=0.0
        )
        # configure metrics here
        self.configure_metrics()
        self.metric = load_metric("squad")
        # Decoded reference / generated texts collected by predict_step.
        self.gts = []
        self.preds = []
        self.tokenizer = AutoTokenizer.from_pretrained(config.task.pretrained_model)
        # The first token id of the encoded "\n\n" is used as the
        # end-of-sequence id during generation.
        self.eos_token_id = self.tokenizer.encode("\n\n", add_special_tokens=False)[0]

    def configure_metrics(self):
        """Create the metric trackers used during training and evaluation."""
        self.training_metrics = {"loss": MovingAverage()}
        self.evaluation_metrics = {"loss": Average(), "f1": Average()}

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Run a teacher-forced pass and record the training loss.

        Args:
            batch: Indexable container whose first element is a mapping with
                the keys ``encoder_input_ids``, ``decoder_input_ids``,
                ``encoder_attention_mask`` and ``target``.

        Returns:
            The output object returned by the BART model (includes ``loss``).
        """
        # Only the first element of the incoming container is consumed.
        batch = batch[0]
        output = self.bart(
            input_ids=batch["encoder_input_ids"],
            decoder_input_ids=batch["decoder_input_ids"],
            attention_mask=batch["encoder_attention_mask"],
            labels=batch["target"],
            return_dict=True,
        )
        self.training_metrics["loss"](output.loss.item())
        return output

    @staticmethod
    def _extract_response(decoded: str) -> str:
        """Return ``"B:"`` followed by the response part of ``decoded``.

        The response is the text between the first and the second ``":"`` of
        the stripped input. When the input contains no ``":"`` at all (for
        example an empty or malformed generation), the whole stripped text is
        used instead of raising ``IndexError``.
        """
        fields = decoded.strip().split(":")
        response = fields[1] if len(fields) > 1 else fields[0]
        return "B:" + response

    def predict_step(self, batch):
        """Score one evaluation batch and collect texts for the F1 metric.

        Records the teacher-forced loss, generates with beam search, and for
        every example whose ``batch["if_empty"]`` entry is falsy appends the
        decoded reference and generation to ``self.gts`` / ``self.preds``.

        Args:
            batch: Mapping with the keys ``encoder_input_ids``,
                ``decoder_input_ids``, ``encoder_attention_mask``, ``target``
                and ``if_empty``.

        Returns:
            None. Results are accumulated on the instance.
        """
        output = self.bart(
            input_ids=batch["encoder_input_ids"],
            decoder_input_ids=batch["decoder_input_ids"],
            attention_mask=batch["encoder_attention_mask"],
            labels=batch["target"],
            return_dict=True,
        )
        self.evaluation_metrics["loss"](output.loss.item())

        outputs = self.bart.generate(
            input_ids=batch["encoder_input_ids"],
            attention_mask=batch["encoder_attention_mask"],
            decoder_start_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.eos_token_id,
            forced_eos_token_id=self.eos_token_id,
            max_length=128,
            length_penalty=0.0,
            num_beams=4,
            do_sample=False,
            return_dict_in_generate=True,
        )

        # The decoder input ids serve as the reference token sequence.
        all_gt_tokens = batch["decoder_input_ids"].tolist()
        all_gen_tokens = outputs.sequences.tolist()
        for idx, gen_tokens in enumerate(all_gen_tokens):
            if batch["if_empty"][idx]:
                continue
            gt_text = self.tokenizer.decode(all_gt_tokens[idx], skip_special_tokens=True)
            gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
            self.gts.append(self._extract_response(gt_text))
            self.preds.append(self._extract_response(gen_text))

        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training loss formatted for display."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return formatted evaluation metrics.

        Returns:
            Mapping with ``loss``, ``ppl`` (``exp(loss)``), ``f1`` (from the
            SQuAD metric over the collected texts, or ``0.0`` when nothing
            was collected) and ``score`` (negative perplexity).
        """
        loss = self.evaluation_metrics["loss"].get_metric()

        preds = [{"prediction_text": text, "id": str(idx)} for idx, text in enumerate(self.preds)]
        # The references carry a fixed placeholder value for "answer_start".
        gts = [
            {"answers": {"answer_start": [100], "text": [text]}, "id": str(idx)} for idx, text in enumerate(self.gts)
        ]
        if preds:
            results = self.metric.compute(predictions=preds, references=gts)
        else:
            results = {"em": 0.0, "f1": 0.0}

        ppl = np.exp(loss)
        # The score is the negated perplexity, so a lower perplexity gives a
        # higher score.
        score = -ppl

        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
            "f1": f"{results['f1']:8.4f}",
            "score": f"{score:8.4f}",
        }

        return metrics

    def validation_loop(self, dataloader):
        """Run the parent validation loop once the step threshold is met.

        Args:
            dataloader: Passed through unchanged to the parent implementation.
        """
        # NOTE(review): with a threshold of 0 this presumably always runs;
        # raise the threshold to skip validation early in training.
        if self.trainer.global_step_count >= 0:
            super().validation_loop(dataloader)

    def reset_evaluation_metrics(self):
        """Reset the evaluation trackers and clear the collected texts."""
        super().reset_evaluation_metrics()
        self.gts = []
        self.preds = []
