from typing import Any, Dict, Tuple, List, Union
import collections
from collections import OrderedDict
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper
from torchfly.utilities import move_to_device

from transformers import BartModel, BartForConditionalGeneration, PretrainedConfig, AutoTokenizer
from datasets import load_metric

from memformers.models.memformerE3_base.memformer_memory import MemformerMemory
from memformers.models.memformerE3_base.modeling_memformer import MemformerForConditionalGeneration
from memformers.models.memformerE3_base.utils import get_model_config
from memformers.recurrent_training.recurrent_training_cell import RecurrentTrainingCell
from memformers.recurrent_training.recurrent_training_model import RecurrentTrainingModel
from memformers.recurrent_training.recurrent_training_flymodel import RecurrentTrainingFlyModel

def process_weights(state_dict):
    """Return a renamed copy of ``state_dict`` without the training-cell wrapper segment.

    Every occurrence of ``".recurrent_training_cell.cell"`` inside a key is
    removed. Values are passed through untouched and the iteration order of
    the input is preserved.

    Args:
        state_dict: Mapping from parameter names to their values.

    Returns:
        A new ``OrderedDict`` with the rewritten keys.
    """
    wrapper_segment = ".recurrent_training_cell.cell"
    return OrderedDict(
        (name.replace(wrapper_segment, ""), tensor) for name, tensor in state_dict.items()
    )


def sync_linear(a, b):
    """Overwrite the weight and bias of ``a`` in place with the values from ``b``.

    Args:
        a: Destination module; its ``weight`` and ``bias`` data are overwritten.
        b: Source module whose ``weight`` and ``bias`` data are read.
    """
    # Weight first, then bias, each copied through the raw ``.data`` tensors.
    for param_name in ("weight", "bias"):
        getattr(a, param_name).data.copy_(getattr(b, param_name).data)


def sync_ln(a, b):
    """Overwrite the weight and bias of layer norm ``a`` in place with those of ``b``.

    Args:
        a: Destination module; its ``weight`` and ``bias`` data are overwritten.
        b: Source module whose ``weight`` and ``bias`` data are read.
    """
    target_weight = a.weight.data
    target_weight.copy_(b.weight.data)
    target_bias = a.bias.data
    target_bias.copy_(b.bias.data)


class MemformerTrainingCell(RecurrentTrainingCell):
    """Recurrent training cell wrapping a ``MemformerForConditionalGeneration`` model.

    ``forward`` runs only the encoder of the wrapped model on one step of
    input and returns the encoder outputs together with the updated memory.
    """

    def __init__(self, config):
        """Build the wrapped model and initialise its weights.

        Args:
            config: Run configuration. ``config.task.pretrained_weights_path``
                is read: when it is not ``None`` the checkpoint at that path
                is loaded, otherwise the weights are initialised from
                ``facebook/bart-base`` via ``load_pretrained_weights``.
        """
        super().__init__()
        self.model_config = PretrainedConfig.from_dict(get_model_config())
        # All three dropout settings of the model configuration are switched off.
        self.model_config.dropout = 0.0
        self.model_config.activation_dropout = 0.0
        self.model_config.attention_dropout = 0.0
        self.cell = MemformerForConditionalGeneration(self.model_config)

        weights_path = config.task.pretrained_weights_path
        if weights_path is not None:
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            # map_location="cpu" lets a checkpoint that was saved from an
            # accelerator load on a host without that device; load_state_dict
            # copies the values into the model's existing parameters.
            state_dict = torch.load(weights_path, map_location="cpu")
            state_dict = process_weights(state_dict)
            # strict=False: the result listing missing/unexpected keys is printed.
            print(self.cell.load_state_dict(state_dict, strict=False))
        else:
            self.load_pretrained_weights()

    def load_pretrained_weights(self):
        """Initialise the model from ``facebook/bart-base`` and seed the memory modules.

        The BART state dict is loaded non-strictly into ``self.cell.model``.
        Afterwards each ``mem_*`` module of the encoder is overwritten with
        the values of its non-memory counterpart.
        """
        # load bart pre-trained weights
        bart_model = BartModel.from_pretrained("facebook/bart-base")
        print(self.cell.model.load_state_dict(bart_model.state_dict(), strict=False))

        encoder = self.cell.model.encoder
        sync_ln(encoder.memory_layer_norm, encoder.layernorm_embedding)
        # load memory state dict
        for layer in encoder.layers:
            # FFN
            sync_linear(layer.mem_fc1, layer.fc1)
            sync_linear(layer.mem_fc2, layer.fc2)
            # Self attention
            sync_linear(layer.self_attn.mem_k_proj, layer.self_attn.k_proj)
            sync_linear(layer.self_attn.mem_v_proj, layer.self_attn.v_proj)
            sync_linear(layer.self_attn.mem_q_proj, layer.self_attn.q_proj)
            sync_linear(layer.self_attn.mem_out_proj, layer.self_attn.out_proj)
            # Sync Layer Norm
            sync_ln(layer.mem_self_attn_layer_norm, layer.self_attn_layer_norm)
            sync_ln(layer.mem_final_layer_norm, layer.final_layer_norm)

    def forward(self, inputs: Any, memory: MemformerMemory) -> Tuple[Any, MemformerMemory]:
        """Run the encoder on one step of input.

        Args:
            inputs: Mapping that provides ``"reset"``, ``"encoder_input_ids"``
                and ``"encoder_attention_mask"``.
            memory: Memory carried over from the previous step; its
                ``memory_states`` are fed to the encoder.

        Returns:
            A tuple of the encoder outputs and a new ``MemformerMemory`` built
            from the ``memory_states`` of those outputs.
        """
        reset_signals = inputs["reset"].float()
        memory_states = memory.memory_states

        encoder_outputs = self.cell.model.encoder(
            input_ids=inputs["encoder_input_ids"],
            memory_states=memory_states,
            memory_resets=reset_signals,
            attention_mask=inputs["encoder_attention_mask"],
        )

        # The length of the reset tensor is passed in the position where
        # construct_memory passes batch_size.
        return (
            encoder_outputs,
            MemformerMemory(encoder_outputs.memory_states, len(reset_signals)),
        )

    def construct_memory(self, batch_size):
        """Return a fresh ``MemformerMemory`` built by the wrapped model for ``batch_size``."""
        return MemformerMemory(self.cell.model.construct_memory(batch_size), batch_size)


class MemformerFlyModel(RecurrentTrainingFlyModel):
    """Fly model that trains and evaluates a Memformer cell with recurrent training.

    During evaluation it also generates text for every step and collects the
    decoded predictions and references, which are scored with the ``squad``
    metric in ``get_evaluation_metrics``.
    """

    def __init__(self, config):
        """Create the recurrent model, loss, tokenizer and metric.

        Args:
            config: Run configuration, forwarded to the base class and to
                ``MemformerTrainingCell``.
        """
        super().__init__(config)
        recurrent_training_cell = MemformerTrainingCell(config)
        recurrent_training_model = RecurrentTrainingModel(recurrent_training_cell)
        self.model = recurrent_training_model
        # use -100 to avoid bugs
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        # Generation stops at the first token id of the encoded "\n\n" string.
        self.eos_token_id = self.tokenizer.encode("\n\n", add_special_tokens=False)[0]
        self.metric = load_metric("squad")
        # Decoded references and predictions collected by predict_step.
        self.gts = []
        self.preds = []

    def forward(self, rollout, memory):
        """Run the recurrent model over ``rollout`` starting from ``memory``.

        Returns:
            A tuple of the model outputs and the ``memory`` attribute of those
            outputs.
        """
        model_outputs = self.model(rollout, memory, output_history_memories=True, output_rng_states=True)
        return model_outputs, model_outputs.memory

    def compute_step_loss(self, step_input, step_output, training=True):
        """Decode one step and return its language-modelling loss.

        Args:
            step_input: Mapping that provides ``"decoder_input_ids"``,
                ``"decoder_attention_mask"`` and ``"target"``.
            step_output: Encoder output of the step; its ``last_hidden_state``
                and ``encoder_attention_mask`` are fed to the decoder.
            training: When true the training metrics are updated, otherwise
                the evaluation loss metric is updated.

        Returns:
            The cross-entropy loss tensor of the step.
        """
        seq2seq = self.model.recurrent_training_cell.cell.model
        decoder_outputs = seq2seq.decoder(
            input_ids=step_input["decoder_input_ids"],
            attention_mask=step_input["decoder_attention_mask"],
            encoder_hidden_states=step_output.last_hidden_state,
            encoder_attention_mask=step_output.encoder_attention_mask,
            return_dict=True,
        )
        # Logits are obtained by projecting onto the shared embedding matrix.
        lm_logits = F.linear(decoder_outputs.last_hidden_state, seq2seq.shared.weight)
        target = step_input["target"]
        lm_loss = self.loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), target.view(-1))

        if training:
            # log
            self.training_metrics["loss"](lm_loss.item())
            # seq_len x batch_size
            self.training_metrics["tok/s"](
                target.shape[0] * target.shape[1] * self.config.training.num_gpus_per_node
            )
        else:
            self.evaluation_metrics["loss"](lm_loss.item())

        return lm_loss

    def predict_step(self, batch_idx, batch, memory):
        """Evaluate one rollout: compute its loss and collect generated text.

        Args:
            batch_idx: Index of the batch (unused here).
            batch: Rollout given as a sequence; only its first step is scored
                and used for generation.
            memory: Memory to start the rollout from.

        Returns:
            A tuple of the step loss and the memory after the rollout.
        """
        step_output, new_memory = self(batch, memory)
        first_step = batch[0]
        loss = self.compute_step_loss(first_step, step_output.outputs[0], training=False)
        cell = self.model.recurrent_training_cell.cell
        outputs = cell.generate(
            encoder_outputs=step_output.outputs[0],
            decoder_start_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.eos_token_id,
            forced_eos_token_id=self.eos_token_id,
            max_length=128,
            num_beams=4,
            do_sample=False,
            return_dict_in_generate=True,
        )

        all_gt_tokens = first_step["decoder_input_ids"].tolist()
        all_gen_tokens = outputs.sequences.tolist()
        # Examples flagged as empty are left out of the text metric.
        for is_empty, gt_tokens, gen_tokens in zip(first_step["if_empty"], all_gt_tokens, all_gen_tokens):
            if not is_empty:
                self.gts.append(self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip())
                self.preds.append(self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip())

        return loss, new_memory

    def validation_loop(self, dataloader):
        """Run ``predict_step`` over every batch of ``dataloader`` without gradients.

        The memory is constructed once with the configured evaluation batch
        size and carried from one batch to the next.
        """
        # No gradient is needed for validation
        self.eval()
        self.reset_evaluation_metrics()
        with torch.no_grad():
            pbar = tqdm.tqdm(dataloader)
            pbar.mininterval = 2.0
            memory = self.construct_memory(self.config.training.evaluation.batch_size)
            for batch_idx, batch in enumerate(pbar):
                batch = move_to_device(batch, self.device)
                # Each batch is wrapped as a one-step rollout.
                _, memory = self.predict_step(batch_idx, [batch], memory)

    def configure_metrics(self):
        """Create the training and evaluation metric trackers."""
        self.training_metrics = {
            "mem_grad_std": MovingAverage(name="mem_grad_std"),
            "mem_grad_max": MovingAverage(name="mem_grad_max"),
            "loss": MovingAverage(name="loss"),
            "tok/s": Speed(),
        }
        self.evaluation_metrics = {
            "loss": Average(),
            "f1": Average(),
        }

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training metrics formatted as strings."""
        loss = self.training_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        tok_s = self.training_metrics["tok/s"].get_metric()
        lr = self.get_last_lr()[0]
        mem_grad_std = self.training_metrics["mem_grad_std"].get_metric()
        mem_grad_max = self.training_metrics["mem_grad_max"].get_metric()

        metrics = {
            "tok/s": f"{tok_s:5.0f}",
            "lr": f"{lr:3.2e}",
            "mem_grad_std": f"{mem_grad_std:4.2e}",
            "mem_grad_max": f"{mem_grad_max:4.2e}",
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
        }

        return metrics

    def get_evaluation_metrics(self) -> Dict[str, str]:
        """Return loss, perplexity, F1 and score of the evaluation as strings.

        F1 comes from the ``squad`` metric computed over the texts collected
        by ``predict_step``. ``score`` is the negated perplexity.
        """
        loss = self.evaluation_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        score = -ppl

        if self.preds:
            predictions = [{"prediction_text": text, "id": str(idx)} for idx, text in enumerate(self.preds)]
            # NOTE(review): answer_start is a constant placeholder; presumably
            # it is only there to satisfy the metric's input schema.
            references = [
                {"answers": {"answer_start": [100], "text": [text]}, "id": str(idx)}
                for idx, text in enumerate(self.gts)
            ]
            f1 = self.metric.compute(predictions=predictions, references=references)["f1"]
        else:
            # Nothing was collected, so there is nothing for the metric to score.
            f1 = 0.0

        metrics = {
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
            "f1": f"{f1:8.4f}",
            "score": f"{score:8.4f}",
        }

        return metrics

    def construct_memory(self, batch_size) -> MemformerMemory:
        """Return a fresh memory for ``batch_size`` from the training cell."""
        return self.model.recurrent_training_cell.construct_memory(batch_size)

    def set_memory_params(self, memory: MemformerMemory):
        """Do nothing; this model sets no parameters on the memory."""
        pass

    def reset_evaluation_metrics(self):
        """Reset the base-class evaluation metrics and clear the collected texts."""
        super().reset_evaluation_metrics()
        self.gts = []
        self.preds = []