from typing import Any, Dict, Tuple, List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchfly.metrics import Average, MovingAverage, Speed
from torchfly.training.schedulers import WarmupWarpper, WarmupLinearSchedule
from transformers import BartModel, PretrainedConfig

from .modeling_memformer import MemformerModel
from .memformer_memory import MemformerMemory
from .utils import get_model_config
from ...recurrent_training.recurrent_training_cell import RecurrentTrainingCell
from ...recurrent_training.recurrent_training_model import RecurrentTrainingModel
from ...recurrent_training.recurrent_training_flymodel import RecurrentTrainingFlyModel

def sync_linear(a, b):
    """Copy the weight and bias values of linear module ``b`` into module ``a`` in place.

    ``a`` keeps its own parameter tensors, so it remains independently trainable.
    Only the values of those tensors are overwritten. A parameter that is ``None`` on
    both modules (e.g. the bias of ``nn.Linear(..., bias=False)``) is skipped.

    Raises:
        ValueError: if ``weight`` or ``bias`` is present on only one of the two modules.
        RuntimeError: (from ``Tensor.copy_``) if ``b``'s parameter cannot be
            broadcast to the shape of ``a``'s.
    """
    # no_grad lets us write into leaf parameters that require grad without
    # recording the copy in the autograd graph.
    with torch.no_grad():
        for attr in ("weight", "bias"):
            dst = getattr(a, attr)
            src = getattr(b, attr)
            if dst is None and src is None:
                continue
            if dst is None or src is None:
                raise ValueError(f"sync_linear: '{attr}' is present on only one of the two modules")
            dst.copy_(src)

def sync_ln(a, b):
    """Copy the affine weight and bias values of LayerNorm ``b`` into ``a`` in place.

    ``a`` keeps its own parameter tensors; only their values are overwritten.
    Parameters that are ``None`` on both modules (e.g. a LayerNorm built with
    ``elementwise_affine=False``) are skipped.

    Raises:
        ValueError: if ``weight`` or ``bias`` is present on only one of the two modules.
        RuntimeError: (from ``Tensor.copy_``) if ``b``'s parameter cannot be
            broadcast to the shape of ``a``'s.
    """
    # no_grad lets us write into leaf parameters that require grad without
    # recording the copy in the autograd graph.
    with torch.no_grad():
        for attr in ("weight", "bias"):
            dst = getattr(a, attr)
            src = getattr(b, attr)
            if dst is None and src is None:
                continue
            if dst is None or src is None:
                raise ValueError(f"sync_ln: '{attr}' is present on only one of the two modules")
            dst.copy_(src)

class MemformerTrainingCell(RecurrentTrainingCell):
    """Recurrent training cell wrapping a ``MemformerModel``.

    On construction the model is built from ``get_model_config()`` and initialised
    from the ``facebook/bart-base`` checkpoint (see ``load_pretrained_weights``).
    The shared token embedding and both positional embeddings are then frozen.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): ``config`` is accepted for interface compatibility but is not read here.
        self.model_config = PretrainedConfig.from_dict(get_model_config())
        self.cell = MemformerModel(self.model_config)

        self.load_pretrained_weights()

        # Freeze the shared token embedding and the encoder/decoder positional embeddings.
        frozen_weights = (
            self.cell.shared.weight,
            self.cell.encoder.embed_positions.weight,
            self.cell.decoder.embed_positions.weight,
        )
        for frozen in frozen_weights:
            frozen.requires_grad = False

    def load_pretrained_weights(self):
        """Initialise ``self.cell`` from BART-base, including its memory-side modules.

        Steps:
        1. Non-strictly load the BART-base state dict and print the load result.
        2. Set ``encoder.memory_bias`` to BART token embeddings plus small noise.
           Row 0 comes from token id 0; the other rows come from a sorted random
           sample of ids in ``[2, 5000)``. The noise is Gaussian with std 0.002.
        3. Copy every pretrained encoder sub-module into its ``mem_*``
           counterpart with ``sync_linear`` / ``sync_ln``.
        """
        bart = BartModel.from_pretrained("facebook/bart-base")
        load_report = self.cell.load_state_dict(bart.state_dict(), strict=False)
        print(load_report)

        encoder = self.cell.encoder
        num_slots = self.model_config.memory_len
        hidden_size = self.model_config.d_model

        # Draw memory_len - 1 token ids (duplicates possible), sort them, and prepend id 0.
        drawn_ids = np.sort(np.random.randint(2, 5000, num_slots - 1))
        slot_token_ids = np.append([0], drawn_ids)
        slot_embeddings = bart.shared.weight[slot_token_ids]
        encoder.memory_bias.data = slot_embeddings + (torch.randn(num_slots, hidden_size) * 0.002)

        sync_ln(encoder.memory_layer_norm, encoder.layernorm_embedding)

        # Each memory-side module starts as a copy of its pretrained counterpart.
        for block in encoder.layers:
            attention = block.self_attn
            linear_pairs = (
                (block.mem_fc1, block.fc1),
                (block.mem_fc2, block.fc2),
                (attention.mem_k_proj, attention.k_proj),
                (attention.mem_v_proj, attention.v_proj),
                (attention.mem_q_proj, attention.q_proj),
                (attention.mem_out_proj, attention.out_proj),
            )
            for memory_side, pretrained_side in linear_pairs:
                sync_linear(memory_side, pretrained_side)
            sync_ln(block.mem_self_attn_layer_norm, block.self_attn_layer_norm)
            sync_ln(block.mem_final_layer_norm, block.final_layer_norm)

    def forward(self, inputs: Any, memory: MemformerMemory) -> Tuple[Any, MemformerMemory]:
        """Run one recurrent encoder step and return ``(encoder_outputs, next_memory)``.

        ``inputs`` must provide the keys ``"encoder_input_ids"``, ``"reset"`` and
        ``"encoder_attention_mask"``. The batch size of the returned memory is
        ``len(inputs["reset"])``.
        """
        previous_states = memory.memory_states
        step_outputs = self.cell.encoder(
            input_ids=inputs["encoder_input_ids"],
            memory_states=previous_states,
            memory_resets=inputs["reset"],
            attention_mask=inputs["encoder_attention_mask"],
        )
        next_memory = MemformerMemory(step_outputs.memory_states, len(inputs["reset"]))
        return step_outputs, next_memory

    def construct_memory(self, batch_size):
        """Return a fresh ``MemformerMemory`` built by ``self.cell`` for ``batch_size``."""
        initial_states = self.cell.construct_memory(batch_size)
        return MemformerMemory(initial_states, batch_size)


class MemformerFlyModel(RecurrentTrainingFlyModel):
    """TorchFly model that trains a Memformer recurrently.

    The encoder (with memory) runs through ``RecurrentTrainingModel``. For each step,
    ``compute_step_loss`` runs the decoder against the encoder outputs and projects
    the result onto the vocabulary with the shared token-embedding matrix.
    """

    def __init__(self, config):
        super().__init__(config)
        recurrent_training_cell = MemformerTrainingCell(config)
        self.model = RecurrentTrainingModel(recurrent_training_cell)
        # Target positions labelled -100 are excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=config.task.label_smoothing)

    def forward(self, rollout, memory):
        """Run the recurrent model over ``rollout`` from ``memory``.

        Returns ``(model_outputs, model_outputs.memory)``.
        """
        model_outputs = self.model(rollout, memory, output_history_memories=True, output_rng_states=True)
        return model_outputs, model_outputs.memory

    def compute_step_loss(self, step_input, step_output):
        """Decode one step against the encoder output and return its training loss.

        The loss is the cross-entropy of the decoder logits against
        ``step_input["target"]``, plus the mean of the squared output of
        ``memory_gate_net``. That gate is evaluated on the layer-normalised
        ``memory_bias``. Also updates the ``loss`` and ``tok/s`` training metrics.
        """
        cell = self.model.recurrent_training_cell.cell
        decoder_outputs = cell.decoder(
            input_ids=step_input["decoder_input_ids"],
            attention_mask=step_input["decoder_attention_mask"],
            encoder_hidden_states=step_output.last_hidden_state,
            encoder_attention_mask=step_output.encoder_attention_mask,
            return_dict=True,
        )
        # The output projection is tied to the shared token embedding.
        lm_logits = F.linear(decoder_outputs.last_hidden_state, cell.shared.weight)
        target = step_input["target"]
        # reshape (not view) so a non-contiguous target tensor is also accepted.
        masked_lm_loss = self.loss_fct(lm_logits.reshape(-1, lm_logits.shape[-1]), target.reshape(-1))

        # log
        self.training_metrics["loss"](masked_lm_loss.item())
        # Token count = product of the two target dims (ignored positions included),
        # scaled by the number of GPUs per node.
        self.training_metrics["tok/s"](target.shape[0] * target.shape[1] * self.config.training.num_gpus_per_node)

        # Squared penalty that pushes the gate computed on the initial memory towards zero.
        encoder = cell.encoder
        init_memory_states = encoder.memory_layer_norm(encoder.memory_bias)
        init_gate = encoder.memory_gate_net(init_memory_states)

        return masked_lm_loss + (init_gate ** 2).mean()

    def configure_metrics(self):
        """Create the training (moving-average / speed) and evaluation metric trackers."""
        self.training_metrics = {
            "mem_grad_std": MovingAverage(name="mem_grad_std"),
            "mem_grad_max": MovingAverage(name="mem_grad_max"),
            "loss": MovingAverage(name="loss"),
            "tok/s": Speed(),
        }
        self.evaluation_metrics = {"loss": Average()}

    def get_training_metrics(self) -> Dict[str, str]:
        """Return the current training metrics formatted as display strings.

        Requires ``configure_optimizers`` to have run, because it binds ``self.get_last_lr``.
        """
        loss = self.training_metrics["loss"].get_metric()
        ppl = np.exp(loss)
        tok_s = self.training_metrics["tok/s"].get_metric()
        # Index 1 is the memory-module parameter group (second group in configure_optimizers).
        lr = self.get_last_lr()[1]
        mem_grad_std = self.training_metrics["mem_grad_std"].get_metric()
        mem_grad_max = self.training_metrics["mem_grad_max"].get_metric()

        return {
            "tok/s": f"{tok_s:5.0f}",
            "lr": f"{lr:3.2e}",
            "mem_grad_std": f"{mem_grad_std:4.2e}",
            "mem_grad_max": f"{mem_grad_max:4.2e}",
            "loss": f"{loss:8.4f}",
            "ppl": f"{ppl:8.4f}",
        }

    def construct_memory(self, batch_size) -> MemformerMemory:
        """Return a fresh memory for ``batch_size`` sequences, delegated to ``self.model``."""
        return self.model.construct_memory(batch_size)

    def set_memory_params(self, memory: MemformerMemory):
        """No-op hook: this model has no memory parameters to set."""
        pass

    def configure_optimizers(self, config, total_num_update_steps) -> Tuple[List, List]:
        """Build AdamW with two parameter groups plus a warmup-then-linear LR schedule.

        Parameters whose name contains ``"mem"`` or ``"decoder"`` form the second
        ("memory module") group and use ``memory_learning_rate`` and
        ``memory_weight_decay``. All other parameters form the first
        ("pretrained") group and use ``learning_rate`` and ``weight_decay``.

        Also binds ``self.get_last_lr`` to the scheduler's ``get_last_lr``.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        memory_module_parameters = []
        pretrained_parameters = []

        for name, param in self.model.named_parameters():
            # Each substring must be tested separately: `"mem" or "decoder" in name`
            # is always truthy and would put every parameter in the memory group.
            if "mem" in name or "decoder" in name:
                memory_module_parameters.append(param)
            else:
                pretrained_parameters.append(param)

        # Order matters: get_training_metrics reads index 1 as the memory-module lr.
        optimizer_grouped_parameters = [
            {
                "params": pretrained_parameters,
                "lr": config.optimization.learning_rate,
                "weight_decay": config.optimization.weight_decay,
            },
            {
                "params": memory_module_parameters,
                "lr": config.optimization.memory_learning_rate,
                "weight_decay": config.optimization.memory_weight_decay,
            },
        ]

        betas = config.optimization.get("betas", (0.9, 0.999))
        warmup_steps = config.scheduler.warmup_steps

        # The top-level lr is only a default; every group above sets its own lr.
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.optimization.learning_rate, betas=betas)

        scheduler = WarmupLinearSchedule(optimizer, warmup_steps, total_num_update_steps)

        self.get_last_lr = scheduler.get_last_lr

        return [optimizer], [scheduler]