from typing import Any, Dict, Tuple, List, Union
from collections import OrderedDict
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper
from torchfly.utilities import move_to_device

from transformers import BartModel, PretrainedConfig
from transformers import BartModel, BartForConditionalGeneration

from memformers.models.BartPrompt_base.modeling_BartPrompt import BartPromptForConditionalGeneration
from memformers.models.BartPrompt_base.utils import get_model_config
from memformers.recurrent_training.recurrent_training_cell import RecurrentTrainingCell
from memformers.recurrent_training.recurrent_training_model import RecurrentTrainingModel
from memformers.recurrent_training.recurrent_training_flymodel import RecurrentTrainingFlyModel

from dummy_memory import NoneMemory

def process_weights(state_dict):
    """Remove the recurrent-training wrapper segment from state-dict keys.

    Every occurrence of ``".recurrent_training_cell.cell"`` is deleted from each
    key; values are passed through unchanged. Iteration order of ``state_dict``
    is preserved. If two keys collapse to the same name, the later value wins
    (the entry keeps the position of its first occurrence).

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        An ``OrderedDict`` with the rewritten keys.
    """
    wrapper_segment = ".recurrent_training_cell.cell"
    renamed_items = (
        (name.replace(wrapper_segment, ""), tensor)
        for name, tensor in state_dict.items()
    )
    return OrderedDict(renamed_items)

class BartTrainingCell(RecurrentTrainingCell):
    """Recurrent-training cell that runs only the BART-Prompt encoder.

    The cell carries no state between steps: ``forward`` ignores the incoming
    memory and always returns a fresh ``NoneMemory(None, 1)``. In this file the
    decoder is applied separately, in ``BartFlyModel.compute_step_loss``.
    """

    def __init__(self, config):
        """Build the BART-Prompt model and load ``facebook/bart-base`` weights into it.

        Args:
            config: training config. It is not read here; the architecture
                comes from ``get_model_config()``.
        """
        super().__init__()
        self.model_config = PretrainedConfig.from_dict(get_model_config())
        # self.model_config.attention_dropout = 0.1
        self.cell = BartPromptForConditionalGeneration(self.model_config)

        # Initialise from the pretrained BART checkpoint. strict=False tolerates
        # keys that exist in only one of the two models; the printed result
        # (missing_keys / unexpected_keys) shows which ones were skipped.
        bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        print(self.cell.load_state_dict(bart_model.state_dict(), strict=False))

    def forward(self, inputs: Any, memory: NoneMemory) -> Tuple[Any, NoneMemory]:
        """Encode one rollout step.

        Args:
            inputs: step batch; must provide ``"encoder_input_ids"`` and
                ``"encoder_attention_mask"``.
            memory: ignored; this cell keeps no recurrent state.

        Returns:
            ``(encoder_outputs, NoneMemory(None, 1))``. ``encoder_outputs`` is a
            dict-style model output, so downstream code can use attribute
            access such as ``.last_hidden_state``.
        """
        encoder_outputs = self.cell.model.encoder(
            input_ids=inputs["encoder_input_ids"],
            attention_mask=inputs["encoder_attention_mask"],
            # Request the dict-style output explicitly (as the decoder call in
            # compute_step_loss does) instead of depending on the config default.
            return_dict=True,
        )

        return (encoder_outputs, NoneMemory(None, 1))

    def construct_memory(self, batch_size):
        """Return the placeholder memory object; ``batch_size`` is not used."""
        return NoneMemory(None, 1)


class BartFlyModel(RecurrentTrainingFlyModel):
    """TorchFly model training the BART-Prompt cell as a sequence-to-sequence LM.

    The recurrent cell (``BartTrainingCell``) produces encoder outputs; the
    decoder and the LM head tied to the shared embedding matrix are applied in
    ``compute_step_loss``.
    """

    def __init__(self, config):
        super().__init__(config)
        recurrent_training_cell = BartTrainingCell(config)
        self.model = RecurrentTrainingModel(recurrent_training_cell)
        # Target positions equal to -100 are excluded from the loss
        # (presumably padding / unlabeled tokens -- confirm against the data pipeline).
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, rollout, memory):
        """Run the recurrent model over ``rollout``.

        Returns:
            ``(model_outputs, None)``. NOTE(review): the second slot is always
            ``None``, so callers that thread it back as memory (``predict_step``)
            pass ``None`` on later steps. This is harmless only while the cell
            ignores its memory argument.
        """
        model_outputs = self.model(rollout, memory, output_history_memories=True, output_rng_states=True)
        return model_outputs, None

    def compute_step_loss(self, step_input, step_output, training=True):
        """Decode one step and return the token-level cross-entropy loss.

        Args:
            step_input: step batch; must provide ``"decoder_input_ids"``,
                ``"decoder_attention_mask"``, ``"encoder_attention_mask"`` and
                ``"target"``.
            step_output: encoder output for this step (read via
                ``.last_hidden_state``).
            training: if True, log to the training metrics, otherwise to the
                evaluation metrics.

        Returns:
            The scalar LM loss tensor.
        """
        bart = self.model.recurrent_training_cell.cell
        decoder_outputs = bart.model.decoder(
            input_ids=step_input["decoder_input_ids"],
            attention_mask=step_input["decoder_attention_mask"],
            encoder_hidden_states=step_output.last_hidden_state,
            # Keep cross-attention off padded encoder positions. The mask comes
            # from the step input: the encoder output does not carry it.
            encoder_attention_mask=step_input["encoder_attention_mask"],
            return_dict=True,
        )
        # LM head tied to the shared embedding matrix (no bias term).
        lm_logits = F.linear(decoder_outputs.last_hidden_state, bart.model.shared.weight)
        target = step_input["target"]
        lm_loss = self.loss_fct(lm_logits.reshape(-1, lm_logits.shape[-1]), target.reshape(-1))

        if training:
            self.training_metrics["loss"](lm_loss.item())
            # Number of target tokens in this step, scaled by GPUs per node.
            self.training_metrics["tok/s"](target.numel() * self.config.training.num_gpus_per_node)
        else:
            self.evaluation_metrics["loss"](lm_loss.item())

        return lm_loss

    def predict_step(self, batch_idx, batch, memory):
        """Evaluate one rollout (``batch`` is a list of step inputs) and log its loss.

        Returns:
            ``(loss, new_memory)``, where ``new_memory`` is the second value
            returned by ``forward``.
        """
        step_output, new_memory = self(batch, memory)
        loss = self.compute_step_loss(batch[0], step_output.outputs[0], training=False)
        return loss, new_memory

    def validation_loop(self, dataloader):
        """Accumulate the evaluation loss over ``dataloader`` without gradients.

        The module's previous train/eval mode is restored afterwards, even if
        an exception is raised.
        """
        was_training = self.training
        self.eval()
        self.reset_evaluation_metrics()
        try:
            with torch.no_grad():
                pbar = tqdm.tqdm(dataloader, mininterval=2.0)
                memory = self.construct_memory(self.config.training.evaluation.batch_size)
                for batch_idx, batch in enumerate(pbar):
                    batch = move_to_device(batch, self.device)
                    _, memory = self.predict_step(batch_idx, [batch], memory)
        finally:
            self.train(was_training)

    def configure_metrics(self):
        """Create the training (moving-average loss, throughput) and evaluation (average loss) metrics."""
        self.training_metrics = {
            "loss": MovingAverage(name="loss"),
            "tok/s": Speed(),
        }
        self.evaluation_metrics = {"loss": Average()}

    def get_training_metrics(self) -> Dict[str, str]:
        """Return formatted throughput, learning rate, loss and perplexity (exp of loss)."""
        loss = self.training_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        tok_s = self.training_metrics["tok/s"].get_metric()
        # get_last_lr is not defined in this class; presumably provided by the
        # base class / optimizer setup -- confirm.
        lr = self.get_last_lr()[0]

        metrics = {
            "tok/s": f"{tok_s:5.0f}",
            "lr": f"{lr:3.2e}",
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
        }

        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return formatted evaluation loss, perplexity and a score (negative perplexity, so higher is better)."""
        loss = self.evaluation_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        score = -ppl

        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
            "score": f"{score:8.4f}",
        }

        return metrics

    def construct_memory(self, batch_size) -> NoneMemory:
        """Delegate memory construction to the recurrent training cell."""
        return self.model.recurrent_training_cell.construct_memory(batch_size)

    def set_memory_params(self, memory: NoneMemory):
        """No-op: this model has no memory parameters to set."""
        pass

    # def configure_optimizers(self, config, total_num_update_steps) -> Union[List, List]:
    #     optimizer_grouped_parameters = self.model.parameters()

    #     betas = config.optimization.get("betas", (0.9, 0.999))
    #     warmup_steps = config.scheduler.warmup_steps

    #     optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.optimization.learning_rate, betas=betas)

    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #         optimizer, total_num_update_steps - warmup_steps, eta_min=config.scheduler.eta_min
    #     )

    #     scheduler = WarmupWarpper(scheduler, warmup_steps=warmup_steps, total_num_update_steps=total_num_update_steps)

    #     self.get_last_lr = scheduler.get_last_lr

    #     return [optimizer], [scheduler]
