from typing import Any, Dict, Tuple, List, Union
from collections import OrderedDict
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper
from torchfly.utilities import move_to_device

from transformers import BartModel, PretrainedConfig

from memformers.models.memformerA4_base.memformer_memory import MemformerMemory
from memformers.models.memformerA4_base.modeling_memformer import MemformerForConditionalGeneration
from memformers.models.memformerA4_base.utils import get_model_config
from memformers.recurrent_training.recurrent_training_cell import RecurrentTrainingCell
from memformers.recurrent_training.recurrent_training_model import RecurrentTrainingModel
from memformers.recurrent_training.recurrent_training_flymodel import RecurrentTrainingFlyModel

def process_weights(state_dict):
    """Strip the recurrent-training wrapper segment from checkpoint keys.

    Every occurrence of ``".recurrent_training_cell.cell"`` is removed from each
    key; values are passed through untouched and the input iteration order is
    kept. If two keys collapse to the same name, the later value wins while the
    position of the first occurrence is kept.

    Args:
        state_dict: Mapping from parameter name to value (anything with ``items()``).

    Returns:
        An ``OrderedDict`` with the rewritten keys.
    """
    wrapper_segment = ".recurrent_training_cell.cell"
    return OrderedDict(
        (name.replace(wrapper_segment, ""), tensor)
        for name, tensor in state_dict.items()
    )

class MemformerTrainingCell(RecurrentTrainingCell):
    """Recurrent training cell that runs the Memformer encoder on one segment.

    Each call encodes one segment conditioned on the incoming memory and returns
    the encoder outputs together with the updated memory. The decoder is not
    invoked inside this cell.
    """

    def __init__(self, config):
        """Build the Memformer model and load pretrained weights into it.

        Args:
            config: Experiment config; ``config.task.pretrained_weights_path``
                must point to a checkpoint loadable with ``torch.load``.
        """
        super().__init__()
        self.model_config = get_model_config()
        self.model_config.activation_dropout = 0.0
        self.cell = MemformerForConditionalGeneration(self.model_config)

        # Load onto CPU: without map_location, torch.load restores every tensor to
        # the device it was saved from, which fails on hosts lacking that device
        # and needlessly occupies GPU memory. load_state_dict copies the values
        # into the model's own parameters, wherever those live.
        state_dict = torch.load(config.task.pretrained_weights_path, map_location="cpu")
        # Drop the ".recurrent_training_cell.cell" segment that checkpoints saved
        # from the wrapped training model carry in their keys.
        state_dict = process_weights(state_dict)
        # strict=False tolerates missing/unexpected keys; the returned report
        # (missing_keys / unexpected_keys) is printed so mismatches are visible.
        print(self.cell.load_state_dict(state_dict, strict=False))

    def forward(self, inputs: Any, memory: MemformerMemory) -> Tuple[Any, MemformerMemory]:
        """Encode one segment, first resetting memory entries flagged in ``inputs["reset"]``.

        Args:
            inputs: Mapping with keys ``"reset"``, ``"encoder_input_ids"`` and
                ``"encoder_attention_mask"``.
            memory: Memory carried over from the previous segment.

        Returns:
            ``(encoder_outputs, new_memory)`` where ``new_memory`` wraps
            ``encoder_outputs.memory_states`` with ``len(inputs["reset"])`` as
            its batch size.
        """
        reset_signals = inputs["reset"]
        # NOTE(review): assumes a 1-D per-sample flag tensor over the batch, so
        # nonzero(...).squeeze(-1) yields the indices of samples to reset.
        reset_indices = torch.nonzero(reset_signals).squeeze(-1)
        num_reset = reset_indices.shape[0]

        memory_states = memory.memory_states

        if num_reset > 0:
            # Out-of-place write: entries of memory_states selected (first dim) by
            # reset_indices are replaced by a freshly constructed single-sample
            # memory, presumably shaped (1, ...) so it broadcasts to all of them.
            memory_states = memory_states.index_put((reset_indices,), self.cell.construct_memory(1))

        encoder_outputs = self.cell.model.encoder(
            input_ids=inputs["encoder_input_ids"],
            memory_states=memory_states,
            attention_mask=inputs["encoder_attention_mask"],
        )

        return (
            encoder_outputs,
            MemformerMemory(encoder_outputs.memory_states, len(reset_signals)),
        )

    def construct_memory(self, batch_size):
        """Return a newly constructed ``MemformerMemory`` for ``batch_size`` samples."""
        return MemformerMemory(self.cell.model.construct_memory(batch_size), batch_size)


class MemformerFlyModel(RecurrentTrainingFlyModel):
    """TorchFly model that trains Memformer recurrently over segments.

    The encoder and memory update run through ``RecurrentTrainingModel`` (built
    around ``MemformerTrainingCell``); the decoder and language-modeling loss are
    computed per step in ``compute_step_loss``.
    """

    def __init__(self, config):
        super().__init__(config)
        recurrent_training_cell = MemformerTrainingCell(config)
        recurrent_training_model = RecurrentTrainingModel(recurrent_training_cell)
        self.model = recurrent_training_model
        # Target positions equal to -100 are excluded from the loss. This is also
        # CrossEntropyLoss's default ignore_index; it is kept explicit on purpose.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, rollout, memory):
        """Run the recurrent model over ``rollout``; return ``(outputs, outputs.memory)``."""
        model_outputs = self.model(rollout, memory, output_history_memories=True, output_rng_states=True)
        return model_outputs, model_outputs.memory

    def compute_step_loss(self, step_input, step_output, training=True):
        """Decode one step against the encoder output and return its LM loss.

        Args:
            step_input: Mapping with ``"decoder_input_ids"``,
                ``"decoder_attention_mask"`` and ``"target"``.
            step_output: Encoder output for this step; must expose
                ``last_hidden_state``, ``encoder_attention_mask`` and, when
                ``training`` is true, ``memory_writer_attentions``.
            training: If true, update training metrics; otherwise update the
                evaluation loss metric.

        Returns:
            The scalar cross-entropy loss tensor.
        """
        cell_model = self.model.recurrent_training_cell.cell.model
        decoder_outputs = cell_model.decoder(
            input_ids=step_input["decoder_input_ids"],
            attention_mask=step_input["decoder_attention_mask"],
            encoder_hidden_states=step_output.last_hidden_state,
            encoder_attention_mask=step_output.encoder_attention_mask,
            return_dict=True,
        )
        # Output projection reuses the `shared` embedding weight (no separate
        # head and no bias).
        lm_logits = F.linear(decoder_outputs.last_hidden_state, cell_model.shared.weight)
        targets = step_input["target"]
        # reshape (rather than view) so a non-contiguous target does not raise.
        lm_loss = self.loss_fct(lm_logits.reshape(-1, lm_logits.shape[-1]), targets.reshape(-1))

        if training:
            self.training_metrics["loss"](lm_loss.item())
            # Number of target tokens this step (product of the target's first
            # two dims; original note says seq_len x batch_size), scaled by GPUs.
            self.training_metrics["tok/s"](
                targets.shape[0] * targets.shape[1] * self.config.training.num_gpus_per_node
            )
            self.training_metrics["mem_keep_prob_max"](step_output.memory_writer_attentions.max().item())
            # NOTE(review): index [:, 0, :, 0] selects a specific head/slot of the
            # memory-writer attention; exact axis meaning not visible here.
            self.training_metrics["mem_keep_prob_mean"](step_output.memory_writer_attentions[:, 0, :, 0].mean().item())
        else:
            self.evaluation_metrics["loss"](lm_loss.item())

        return lm_loss

    def predict_step(self, batch_idx, batch, memory):
        """Run one evaluation step; return ``(loss, new_memory)``."""
        step_output, new_memory = self(batch, memory)
        loss = self.compute_step_loss(batch[0], step_output.outputs[0], training=False)
        return loss, new_memory

    def validation_loop(self, dataloader):
        """Evaluate over ``dataloader``, accumulating loss into the evaluation metrics.

        Memory is carried across consecutive batches. Gradients are disabled, and
        the module's previous train/eval mode is restored on exit so that
        training resumed afterwards does not silently run with dropout off.
        """
        was_training = self.training
        self.eval()
        self.reset_evaluation_metrics()
        try:
            with torch.no_grad():
                pbar = tqdm.tqdm(dataloader, mininterval=2.0)
                memory = self.construct_memory(self.config.training.evaluation.batch_size)
                for batch_idx, batch in enumerate(pbar):
                    batch = move_to_device(batch, self.device)
                    _, memory = self.predict_step(batch_idx, [batch], memory)
        finally:
            self.train(was_training)

    def configure_metrics(self):
        """Create the training and evaluation metric trackers."""
        self.training_metrics = {
            # NOTE(review): the mem_grad_* metrics are not updated in this class;
            # presumably fed by the parent class or trainer — verify.
            "mem_grad_std": MovingAverage(name="mem_grad_std"),
            "mem_grad_max": MovingAverage(name="mem_grad_max"),
            # beta=0.0 presumably means no smoothing (latest value) — confirm.
            "mem_keep_prob_max": MovingAverage(name="mem_keep_prob_max", beta=0.0),
            "mem_keep_prob_mean": MovingAverage(name="mem_keep_prob_mean", beta=0.0),
            "loss": MovingAverage(name="loss"),
            "tok/s": Speed(),
        }
        self.evaluation_metrics = {"loss": Average()}

    def get_training_metrics(self) -> Dict[str, str]:
        """Return formatted training metrics, including perplexity ``exp(loss)``.

        Requires ``configure_optimizers`` to have run, since it installs
        ``self.get_last_lr``.
        """
        loss = self.training_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        tok_s = self.training_metrics["tok/s"].get_metric()
        lr = self.get_last_lr()[0]
        mem_grad_std = self.training_metrics["mem_grad_std"].get_metric()
        mem_grad_max = self.training_metrics["mem_grad_max"].get_metric()
        mem_keep_prob_max = self.training_metrics["mem_keep_prob_max"].get_metric()
        mem_keep_prob_mean = self.training_metrics["mem_keep_prob_mean"].get_metric()

        metrics = {
            "tok/s": f"{tok_s:5.0f}",
            "lr": f"{lr:3.2e}",
            "mem_grad_std": f"{mem_grad_std:4.2e}",
            "mem_grad_max": f"{mem_grad_max:4.2e}",
            "mem_keep_prob_max": f"{mem_keep_prob_max:4.2e}",
            "mem_keep_prob_mean": f"{mem_keep_prob_mean:4.2e}",
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
        }

        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return formatted evaluation loss, perplexity and score (negated perplexity)."""
        loss = self.evaluation_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        # Negated so that a larger score corresponds to a lower perplexity.
        score = -ppl

        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
            "score": f"{score:8.4f}",
        }

        return metrics

    def construct_memory(self, batch_size) -> MemformerMemory:
        """Return a newly constructed memory for ``batch_size`` samples."""
        return self.model.recurrent_training_cell.construct_memory(batch_size)

    def set_memory_params(self, memory: MemformerMemory):
        """No-op: this model has no memory parameters to set."""
        pass

    def configure_optimizers(self, config, total_num_update_steps) -> Tuple[List, List]:
        """Create AdamW plus a warmup-wrapped cosine-annealing LR schedule.

        Args:
            config: Reads ``optimization.learning_rate``, optional
                ``optimization.betas`` (default ``(0.9, 0.999)``),
                ``scheduler.warmup_steps`` and ``scheduler.eta_min``.
            total_num_update_steps: Total number of optimizer updates planned.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        optimizer_grouped_parameters = self.model.parameters()

        betas = config.optimization.get("betas", (0.9, 0.999))
        warmup_steps = config.scheduler.warmup_steps

        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.optimization.learning_rate, betas=betas)

        # CosineAnnealingLR divides by (and takes a modulo by) T_max; clamp so a
        # warmup that spans the whole run cannot produce T_max <= 0.
        cosine_steps = max(1, total_num_update_steps - warmup_steps)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, cosine_steps, eta_min=config.scheduler.eta_min
        )

        scheduler = WarmupWarpper(scheduler, warmup_steps=warmup_steps, total_num_update_steps=total_num_update_steps)

        # get_training_metrics reads the current learning rate through this hook.
        self.get_last_lr = scheduler.get_last_lr

        return [optimizer], [scheduler]
