import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Any, Dict
from transformers import BartModel, AutoTokenizer
from datasets import load_metric

from torchfly.training import FlyModel
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed




def get_best_span(
    start_logits,
    end_logits,
    n_best_size: int = 10,
    max_answer_length: int = 30,
):
    """Choose one answer span per example from start/end logits.

    For each row of the batch, the ``n_best_size`` highest-scoring start
    positions are paired with the ``n_best_size`` highest-scoring end
    positions. Pairs whose end precedes their start, or whose inclusive
    length exceeds ``max_answer_length``, are discarded. Among the rest, the
    pair with the largest ``start_logit + end_logit`` wins. Ties go to the
    pair enumerated first (start-major order).

    Args:
        start_logits: start scores indexed as ``[example][position]``.
        end_logits: end scores with the same indexing as ``start_logits``.
        n_best_size: number of top start (and top end) positions to pair up.
        max_answer_length: maximum inclusive span length, in positions.

    Returns:
        A list of ``(start_index, end_index)`` tuples, one per example.
        ``(0, 0)`` is used when no admissible pair exists.
    """

    def top_positions(row):
        # Indices of the ``n_best_size`` largest entries, highest first.
        return torch.argsort(row).tolist()[-1 : -n_best_size - 1 : -1]

    spans = []
    for row_idx in range(start_logits.shape[0]):
        row_start = start_logits[row_idx]
        row_end = end_logits[row_idx]
        starts = top_positions(row_start)
        ends = top_positions(row_end)

        # (score, start, end) for every admissible pairing, in start-major order.
        candidates = [
            (row_start[s] + row_end[e], s, e)
            for s in starts
            for e in ends
            if s <= e and e - s + 1 <= max_answer_length
        ]

        ranked = sorted(candidates, key=lambda cand: cand[0], reverse=True)[:n_best_size]
        if ranked:
            _, best_start, best_end = ranked[0]
            spans.append((best_start, best_end))
        else:
            spans.append((0, 0))

    return spans


class BartEncoderSpanQAFlyModel(FlyModel):
    """Extractive (span) QA model built on the encoder half of BART.

    A linear head maps every encoder hidden state to a start logit and an end
    logit. Training minimizes the mean of the start and end cross-entropy
    losses. Evaluation also decodes the best-scoring span and scores it
    against the gold span with the SQuAD exact-match / F1 metric.

    Batch keys read by this class: ``input_ids``, ``attention_mask``,
    ``start_positions`` and ``end_positions``.
    """

    def __init__(self, config):
        super().__init__(config)

        # All dropout probabilities are explicitly set to 0.0.
        self.bart = BartModel.from_pretrained(
            config.task.pretrained_model,
            dropout=0.0,
            classifier_dropout=0.0,
            activation_dropout=0.0,
            attention_dropout=0.0,
        )

        # Two outputs per token: the start logit and the end logit.
        self.qa_outputs = nn.Linear(self.bart.config.hidden_size, 2)
        # Kept as an attribute for backward compatibility. The loss itself is
        # computed in ``_span_logits_and_loss`` with a per-batch ignore_index.
        self.loss_fct = nn.CrossEntropyLoss()
        # NOTE(review): the tokenizer is hard-coded to bart-base while the model
        # comes from ``config.task.pretrained_model``; confirm the vocabularies match.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

        # configure metrics here
        self.configure_metrics()

    def configure_metrics(self):
        """Create the loss trackers and the SQuAD metric; clear stored examples."""
        self.training_metrics = {"loss": MovingAverage()}
        # "acc" is not updated by any method in this class.
        self.evaluation_metrics = {"loss": Average(), "acc": CategoricalAccuracy()}

        self.metric = load_metric("squad")
        self.all_predictions = []
        self.all_references = []

    def _span_logits_and_loss(self, batch: Dict[str, Any]):
        """Run the encoder and QA head on ``batch`` and compute the span loss.

        Returns:
            A tuple ``(encoder_outputs, start_logits, end_logits,
            start_positions, end_positions, loss)``. The logits have the
            trailing size-1 dim squeezed away. The positions are the gold
            targets clamped to ``[0, seq_len]``.
        """
        outputs = self.bart.encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
        )
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # Gold positions outside the sequence are clamped to ``seq_len`` and
        # must then be excluded through ``ignore_index``. With the default
        # ignore_index (-100), a target equal to seq_len would be an
        # out-of-range class and cross-entropy would raise.
        ignored_index = start_logits.size(1)
        start_positions = batch["start_positions"].clamp(0, ignored_index)
        end_positions = batch["end_positions"].clamp(0, ignored_index)

        start_loss = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index)
        end_loss = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index)
        loss = (start_loss + end_loss) / 2
        return outputs, start_logits, end_logits, start_positions, end_positions, loss

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Training step: return the encoder outputs with ``loss`` attached.

        Also records the loss value in the training moving average.
        """
        outputs, _, _, _, _, loss = self._span_logits_and_loss(batch)

        self.training_metrics["loss"](loss.item())
        outputs.loss = loss
        return outputs

    def predict_step(self, batch):
        """Evaluation step: record the loss and the decoded predicted/gold spans.

        The decoded strings are stripped and lower-cased, then appended to
        ``all_predictions`` / ``all_references`` in the format the SQuAD metric
        expects. One shared id links each prediction to its reference.
        """
        _, start_logits, end_logits, start_positions, end_positions, loss = self._span_logits_and_loss(batch)

        self.evaluation_metrics["loss"](loss.item())

        # Restrict span search to non-padding positions. Padded logits become
        # -inf, so they rank last and any pair using them scores -inf.
        padding = batch["attention_mask"] == 0
        predictions = get_best_span(
            start_logits.masked_fill(padding, float("-inf")),
            end_logits.masked_fill(padding, float("-inf")),
        )

        input_ids = batch["input_ids"]
        for i, (pred_start, pred_end) in enumerate(predictions):
            pred = self.tokenizer.decode(input_ids[i][pred_start : pred_end + 1]).strip().lower()
            reference = (
                self.tokenizer.decode(input_ids[i][start_positions[i] : end_positions[i] + 1]).strip().lower()
            )

            example_id = str(len(self.all_predictions))
            self.all_predictions.append({"prediction_text": pred, "id": example_id})
            self.all_references.append({"answers": {"answer_start": [0], "text": [reference]}, "id": example_id})

        return None

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the formatted moving-average training loss."""
        loss = self.training_metrics["loss"].get_metric()
        metrics = {"loss": f"{loss:.4f}"}
        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return eval loss plus SQuAD EM/F1 over all accumulated examples.

        Each value is a ``(formatted_string, raw_value)`` pair. ``score``
        mirrors F1.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        results = self.metric.compute(predictions=self.all_predictions, references=self.all_references)

        metrics = {
            "loss": (f"{loss:.4f}", loss),
            "em": (f"{results['exact_match']:.4f}", results["exact_match"]),
            "f1": (f"{results['f1']:.4f}", results["f1"]),
            "score": (f"{results['f1']:.4f}", results["f1"]),
        }
        return metrics

    def reset_evaluation_metrics(self):
        """Reset the base-class evaluation metrics and drop accumulated examples."""
        super().reset_evaluation_metrics()
        self.all_predictions = []
        self.all_references = []
