from typing import Any, Dict, Tuple, List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper
from transformers import BartModel, PretrainedConfig

from .modeling_memformer import MemformerModel
from .memformer_memory import MemformerMemory
from .utils import get_model_config
from ...recurrent_training.recurrent_training_cell import RecurrentTrainingCell
from ...recurrent_training.recurrent_training_model import RecurrentTrainingModel
from ...recurrent_training.recurrent_training_flymodel import RecurrentTrainingFlyModel


class MemformerTrainingCell(RecurrentTrainingCell):
    """Recurrent training cell that runs the Memformer encoder for one step.

    The wrapped ``MemformerModel`` is built from ``get_model_config()`` and
    warm-started from the ``facebook/bart-large`` checkpoint.
    """

    def __init__(self, config):
        """Build the Memformer model, load BART weights and freeze the embeddings.

        Args:
            config: Accepted for interface compatibility; it is not read here.
        """
        super().__init__()
        self.model_config = PretrainedConfig.from_dict(get_model_config())
        self.cell = MemformerModel(self.model_config)

        # Warm-start from BART. ``strict=False`` tolerates keys that exist on
        # only one side; the returned key report is printed for inspection.
        pretrained_bart = BartModel.from_pretrained("facebook/bart-large")
        load_report = self.cell.load_state_dict(pretrained_bart.state_dict(), strict=False)
        print(load_report)

        # NOTE: disabled experiment, kept for reference.
        # self.cell.encoder.memory_extract_tokens.data = self.cell.encoder.memory_extract_tokens.data * 0.99 + \
        #     (self.cell.shared.weight.data[:1] + self.cell.encoder.embed_positions.weight.data[:1]).expand(self.model_config.memory_extract_len, -1) * 0.01

        # The shared token embedding and both positional embeddings are frozen.
        for frozen_weight in (
            self.cell.shared.weight,
            self.cell.encoder.embed_positions.weight,
            self.cell.decoder.embed_positions.weight,
        ):
            frozen_weight.requires_grad = False

    def forward(self, inputs: Any, memory: MemformerMemory) -> Tuple[Any, MemformerMemory]:
        """Encode one step of input against the current memory.

        Args:
            inputs: Mapping that provides ``"reset"``, ``"encoder_input_ids"``
                and ``"encoder_attention_mask"``.
            memory: Memory carried over from the previous step.

        Returns:
            The encoder outputs and a new ``MemformerMemory`` built from the
            encoder's ``memory_states``.
        """
        reset_flags = inputs["reset"]
        # Positions whose reset flag is non-zero get their memory re-initialised.
        rows_to_reset = torch.nonzero(reset_flags).squeeze(-1)
        states = memory.memory_states

        if rows_to_reset.shape[0] > 0:
            # A single freshly constructed memory is written into every
            # selected position along the first dimension.
            fresh_memory = self.cell.construct_memory(1)
            states = states.index_put((rows_to_reset,), fresh_memory)

        encoded = self.cell.encoder(
            input_ids=inputs["encoder_input_ids"],
            memory_states=states,
            attention_mask=inputs["encoder_attention_mask"],
        )

        next_memory = MemformerMemory(encoded.memory_states, len(reset_flags))
        return encoded, next_memory

    def construct_memory(self, batch_size):
        """Return a new ``MemformerMemory`` holding initial states for ``batch_size``."""
        initial_states = self.cell.construct_memory(batch_size)
        return MemformerMemory(initial_states, batch_size)


class MemformerFlyModel(RecurrentTrainingFlyModel):
    """Fly-model that trains a ``MemformerTrainingCell`` through ``RecurrentTrainingModel``.

    The encoder (with memory) runs inside ``self.model``; the decoder pass and
    the language-modeling loss are computed per step in ``compute_step_loss``.
    """

    def __init__(self, config):
        """Create the recurrent model and the loss function.

        Args:
            config: Configuration forwarded to the base class and to
                ``MemformerTrainingCell``.
        """
        super().__init__(config)
        recurrent_training_cell = MemformerTrainingCell(config)
        recurrent_training_model = RecurrentTrainingModel(recurrent_training_cell)
        self.model = recurrent_training_model
        # use -100 to avoid bugs: targets equal to -100 are excluded from the loss
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, rollout, memory):
        """Run the recurrent model over ``rollout`` starting from ``memory``.

        Returns:
            A pair ``(model_outputs, model_outputs.memory)``.
        """
        model_outputs = self.model(rollout, memory, output_history_memories=True, output_rng_states=True)
        return model_outputs, model_outputs.memory

    def compute_step_loss(self, step_input, step_output):
        """Decode one step and return its cross-entropy language-modeling loss.

        Args:
            step_input: Mapping that provides ``"decoder_input_ids"``,
                ``"decoder_attention_mask"`` and ``"target"``.
            step_output: Encoder output for the same step; its
                ``last_hidden_state``, ``encoder_attention_mask`` and
                ``memory_writer_attentions`` attributes are read.

        Returns:
            The scalar loss tensor. Training metrics are updated as a side effect.
        """
        memformer = self.model.recurrent_training_cell.cell
        decoder_outputs = memformer.decoder(
            input_ids=step_input["decoder_input_ids"],
            attention_mask=step_input["decoder_attention_mask"],
            encoder_hidden_states=step_output.last_hidden_state,
            encoder_attention_mask=step_output.encoder_attention_mask,
            return_dict=True,
        )
        # The output projection reuses the shared embedding matrix (no bias).
        lm_logits = F.linear(decoder_outputs.last_hidden_state, memformer.shared.weight)
        target = step_input["target"]
        masked_lm_loss = self.loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), target.view(-1))

        # log
        self.training_metrics["loss"](masked_lm_loss.item())
        # Product of the first two target dimensions, scaled by the number of
        # GPUs per node.
        self.training_metrics["tok/s"](
            target.shape[0] * target.shape[1] * self.config.training.num_gpus_per_node
        )
        writer_attentions = step_output.memory_writer_attentions
        self.training_metrics["mem_keep_prob_max"](writer_attentions.max().item())
        # NOTE(review): index 0 on the second and fourth axes is assumed to
        # select the "keep" probability -- confirm against the memory writer.
        self.training_metrics["mem_keep_prob_mean"](writer_attentions[:, 0, :, 0].mean().item())

        return masked_lm_loss

    def configure_metrics(self):
        """Register the training and evaluation metric trackers.

        ``mem_grad_std`` and ``mem_grad_max`` are registered here but are not
        updated anywhere in this class.
        """
        self.training_metrics = {
            "mem_grad_std": MovingAverage(name="mem_grad_std"),
            "mem_grad_max": MovingAverage(name="mem_grad_max"),
            "mem_keep_prob_max": MovingAverage(name="mem_keep_prob_max", beta=0.0),
            "mem_keep_prob_mean": MovingAverage(name="mem_keep_prob_mean", beta=0.0),
            "loss": MovingAverage(name="loss"),
            "tok/s": Speed(),
        }
        self.evaluation_metrics = {"loss": Average()}

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training metrics formatted as display strings.

        Perplexity is derived as ``exp(loss)`` from the tracked loss value.
        """
        loss = self.training_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        tok_s = self.training_metrics["tok/s"].get_metric()
        lr = self.get_last_lr()[0]
        mem_grad_std = self.training_metrics["mem_grad_std"].get_metric()
        mem_grad_max = self.training_metrics["mem_grad_max"].get_metric()
        mem_keep_prob_max = self.training_metrics["mem_keep_prob_max"].get_metric()
        mem_keep_prob_mean = self.training_metrics["mem_keep_prob_mean"].get_metric()

        metrics = {
            "tok/s": f"{tok_s:5.0f}",
            "lr": f"{lr:3.2e}",
            "mem_grad_std": f"{mem_grad_std:4.2e}",
            "mem_grad_max": f"{mem_grad_max:4.2e}",
            "mem_keep_prob_max": f"{mem_keep_prob_max:4.2e}",
            "mem_keep_prob_mean": f"{mem_keep_prob_mean:4.2e}",
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
        }

        return metrics

    def construct_memory(self, batch_size) -> MemformerMemory:
        """Return the initial memory for ``batch_size``, as built by ``self.model``."""
        return self.model.construct_memory(batch_size)

    def set_memory_params(self, memory: MemformerMemory):
        """No-op: this model does nothing with ``memory`` here."""
        pass

    def configure_optimizers(self, config, total_num_update_steps) -> Tuple[List, List]:
        """Create the AdamW optimizer and a warmup + cosine-annealing scheduler.

        Args:
            config: Provides ``optimization.learning_rate``, optional
                ``optimization.betas``, ``scheduler.warmup_steps`` and
                ``scheduler.eta_min``.
            total_num_update_steps: Total number of optimizer updates planned.

        Returns:
            A tuple ``([optimizer], [scheduler])``.
        """
        optimizer_grouped_parameters = self.model.parameters()

        betas = config.optimization.get("betas", (0.9, 0.999))
        warmup_steps = config.scheduler.warmup_steps

        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.optimization.learning_rate, betas=betas)

        # The cosine phase covers the steps left after warmup. Clamp to at
        # least one step: CosineAnnealingLR divides by T_max, so a warmup that
        # spans the whole run would otherwise fail with a division by zero.
        cosine_steps = max(1, total_num_update_steps - warmup_steps)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, cosine_steps, eta_min=config.scheduler.eta_min
        )

        scheduler = WarmupWarpper(scheduler, warmup_steps=warmup_steps, total_num_update_steps=total_num_update_steps)

        # Exposed so get_training_metrics can report the current learning rate.
        self.get_last_lr = scheduler.get_last_lr

        return [optimizer], [scheduler]
