from pytorch_transformers import BertModel
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from transformers import get_cosine_schedule_with_warmup
import torch
from torch import nn
import numpy as np
import pytorch_lightning as pl
from .model_function import *

# Output classes of the classifier; a label's position in this list is its class id.
bc_category = [
    "NoBC",
    "continuer",
    "understanding",
    "empathetic",
]

# Raw annotation label -> class id.
# The first three labels map to their own class; the four remaining
# fine-grained labels all collapse into id 3, the position of "empathetic"
# in ``bc_category``.
bc_category2id = {
    **{label: idx for idx, label in enumerate(bc_category[:3])},
    **dict.fromkeys(
        (
            "negative surprise",
            "positive surprise",
            "request confirmation",
            "affirmative",
        ),
        bc_category.index("empathetic"),
    ),
}

class BackChannelModel(pl.LightningModule):
    """Back-channel classifier built on a BERT sentence encoder.

    The embedding of the current utterance is optionally concatenated with a
    selective-history embedding, a holistic-history embedding and an acoustic
    embedding; the fused vector is classified into one of the ``bc_category``
    classes.

    Automatic optimization is disabled. Two optimizers (one for the pretrained
    parameters, one for everything else) are stepped by hand in
    ``model_update``, with gradients accumulated over
    ``hparams.accumulate_num`` batches.
    """

    def __init__(self, bert_path, **kwargs):
        """Build the encoders, the classifier and the bookkeeping state.

        Args:
            bert_path: Passed to ``BertModel.from_pretrained``.
            **kwargs: Hyperparameters, exposed afterwards as ``self.hparams``.
                This class reads ``classifier_pool_dim``, ``dropout_rate``,
                ``max_tokens``, ``accumulate_num``, ``epochs``,
                ``learning_rate``, ``pretrain_LR``, ``weight_decay`` and
                ``accustic_feature``; the helper factories receive the whole
                ``hparams`` object and may read further keys.
        """
        super().__init__()

        # Save the constructor arguments as hyperparameters.
        self.save_hyperparameters()

        # Load the pretrained BERT model.
        self.sentence_encoder = BertModel.from_pretrained(bert_path)
        sentence_dim = self.sentence_encoder.config.hidden_size
        fusion_dim = sentence_dim

        # Each factory returns (module, updated fusion_dim, enabled flag), so
        # fusion_dim is threaded through to end up as the classifier input size.
        self.selective_history, fusion_dim, self.use_selective_history = get_selective_history_module(
            self.hparams, fusion_dim, sentence_dim
        )

        self.accustic_encoder, fusion_dim, self.use_accustic = get_accustic_encoder(
            self.hparams, fusion_dim
        )

        self.holistic_history, fusion_dim, self.use_holistic_history = get_holistic_history_module(
            self.hparams, fusion_dim, sentence_dim
        )

        # Create a classifier using flexible pooling layers.
        self.classifier = nn.Sequential(
            *make_flexible_pooling_layer(
                start_dim=fusion_dim,
                inter_dim=self.hparams.classifier_pool_dim,
                last_dim=len(bc_category),
                dropout_rate=self.hparams.dropout_rate,
                is_predict_layer=True
            )
        )

        # Define the loss criterion for classification.
        self.criterion = torch.nn.CrossEntropyLoss()

        # Initialize weights and vocabulary.
        self.initialize_weights()
        self.add_speaker_vocab()

        # Reset output logs and disable automatic optimization
        # (optimizers are stepped manually in ``model_update``).
        self.reset_output_logs()
        self.automatic_optimization = False

    def initialize_weights(self):
        """Kaiming-initialize every ``nn.Linear`` in the history module and classifier.

        Sub-modules are walked with ``.modules()``; iterating ``.parameters()``
        yields tensors, which are never ``nn.Linear`` instances, so the
        initialization would silently be skipped.
        """
        candidates = [
            *self.selective_history.modules(),
            *self.classifier.modules(),
        ]
        for module in candidates:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_uniform_(module.weight)
            if getattr(module, "bias", None) is not None:
                nn.init.constant_(module.bias, 0)

    def reset_output_logs(self):
        """Clear the training-loss history and the val/test prediction buffers."""
        self.train_losses = []
        self.outputs = {
            "val": {"pred": [], "target": []},
            "test": {"pred": [], "target": []},
        }

    def add_speaker_vocab(self):
        """Grow the sentence encoder's token embedding table by two entries."""
        # Expand the vocabulary of the sentence encoder model to accommodate speaker embeddings.
        origin_vocab_size = self.sentence_encoder.embeddings.word_embeddings.weight.size(0)
        self.sentence_encoder.resize_token_embeddings(origin_vocab_size + 2)

    def forward(self, x):
        """Classify a batch.

        Args:
            x: Batch dictionary. Besides the keys used by ``get_text_emb``,
                ``"history_size"`` is read when selective history is enabled
                and ``"accustic_feature"`` when the acoustic encoder is enabled.

        Returns:
            ``{"pred": logits}`` where ``logits`` is the classifier output.
        """
        # Compute embeddings for current and history text.
        current_emb, history_emb = self.get_text_emb(x)
        parts = [current_emb]

        # Include selective history embeddings if enabled.
        if self.use_selective_history:
            parts.append(
                self.selective_history(current_emb, history_emb, x["history_size"])
            )

        # Include holistic history embeddings if enabled.
        if self.use_holistic_history:
            # History is reversed along dim 1 before encoding; the last entry
            # of the module's second output is used.
            # NOTE(review): presumably the final hidden state of the last
            # recurrent layer — confirm against get_holistic_history_module.
            parts.append(self.holistic_history(history_emb.flip(1))[1][-1])

        # Include acoustic embeddings if enabled.
        if self.use_accustic:
            parts.append(self.accustic_encoder(x["accustic_feature"]))

        # Concatenate fusion embeddings along the last dimension.
        fusion_emb = torch.cat(parts, dim=-1)

        return {"pred": self.classifier(fusion_emb)}

    def get_text_emb(self, x):
        """
        Compute text embeddings for the current and historical text data.

        All history sentences of the batch come first in the encoded rows,
        followed by the current sentences: the first ``sum(history_size)`` rows
        are treated as history and the remaining rows as current utterances.

        Args:
            x (Dict[str, Tensor]): Dictionary containing input tensors, including "token_len", "input_ids", "input_mask", and "history_size".
                                - input_ids (Tensor) : Input token IDs for the sentences. Shape: [batch_size x (history_size + 1), sequence_length]
                                - input_mask (Tensor) : Attention mask for the input tokens. Shape: [batch_size x (history_size + 1), sequence_length]
                                - token_len (Tensor) : Token length of each sentence. Shape: [batch_size x (history_size + 1), 1]
                                - history_size (Tensor) : Number of history sentences per instance. Shape: [batch_size]

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the embeddings for the current text ("current_emb") and historical text ("history_emb").
                                - current_emb (Tensor): Embeddings for the current text. Shape: [batch_size, hidden_size]
                                - history_emb (Tensor): Embeddings for historical text. Shape: [batch_size, max_history_length, hidden_size]
        """
        # Trim the padded token axis to the longest sentence in this batch.
        max_token_len = torch.max(x["token_len"])

        # Obtain sentence embeddings for the input text (current + history).
        sentence_emb = get_sentence_emb(
            sentence_encoder=self.sentence_encoder,
            token_ids=x["input_ids"][:, :max_token_len],
            attention_mask=x["input_mask"][:, :max_token_len],
            max_tokens=self.hparams.max_tokens
        )

        # Compute the total history size and separate current and historical embeddings.
        total_history_size = torch.sum(x["history_size"])
        current_emb = sentence_emb[total_history_size:]
        history_emb = listtomatrix(
            sentence_emb[:total_history_size],
            np.array(x["history_size"].cpu())
        )
        return current_emb, history_emb

    def on_train_start(self):
        """Register the hyperparameters together with zeroed placeholder metrics.

        The placeholder keys use the same ``<stage>_<metric>`` names that
        ``cal_metric`` logs later, so the logger can associate the subsequently
        logged values with the hyperparameter record.
        """
        # Training may run without a logger; there is nothing to register then.
        if self.logger is None:
            return

        placeholder_metrics = {
            f"{stage}_{metric}": 0
            for stage in ["val", "test"]
            for metric in ["w_f1", "f1", "acc"]
        }
        self.logger.log_hyperparams(self.hparams, placeholder_metrics)

    def shared_step(self, batch, batch_idx):
        """Run the forward pass and compute the cross-entropy loss.

        Args:
            batch: Batch dictionary; ``batch["target"]`` holds the class ids.
            batch_idx: Index of the batch (unused).

        Returns:
            Tuple ``(outputs, loss)`` with ``outputs`` as returned by ``forward``.
        """
        outputs = self(batch)
        loss = self.criterion(outputs["pred"], batch["target"])
        return outputs, loss

    def on_train_epoch_start(self):
        """Zero both optimizers' gradients before the epoch begins."""
        other_opt, bert_opt = self.optimizers()
        other_opt.zero_grad()
        bert_opt.zero_grad()
        return super().on_train_epoch_start()

    def training_step(self, batch, batch_idx):
        """Accumulate gradients for one batch and step every ``accumulate_num`` batches.

        Returns:
            The loss of this batch.
        """
        # Perform a shared training step and obtain the loss.
        _, loss = self.shared_step(batch, batch_idx)

        # Log the moving average of the training loss for the last 10 steps.
        self.train_losses.append(loss.item())
        self.log(
            "train_loss",
            np.mean(self.train_losses[-10:]),
            on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        # Perform manual backward pass (gradients accumulate across batches).
        self.manual_backward(loss)

        # Update model weights every 'accumulate_num' steps.
        if (batch_idx + 1) % self.hparams.accumulate_num == 0:
            other_lr, pretrain_LR = self.model_update()
            self.log('lr', other_lr, on_step=True, prog_bar=True, logger=True)
            self.log('pretrain_LR', pretrain_LR, on_step=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        """Run one more ``model_update`` at the end of the epoch.

        The update is unconditional, so it also applies gradients left over
        from an incomplete accumulation window.
        """
        self.model_update()
        return super().on_train_epoch_end()

    def test_step(self, batch, batch_idx):
        """Same as ``validation_step`` but logged and buffered under ``"test"``."""
        return self.validation_step(batch, batch_idx, print_str='test')

    def validation_step(self, batch, batch_idx, print_str="val"):
        """Evaluate one batch, log its loss and buffer logits and targets.

        Args:
            batch: Batch dictionary containing ``"target"``.
            batch_idx: Index of the batch.
            print_str: Stage name, ``"val"`` or ``"test"``; selects the log
                prefix and the buffer in ``self.outputs``.

        Returns:
            ``{"loss": loss}``.
        """
        outputs, loss = self.shared_step(batch, batch_idx)
        self.log(
            f'{print_str}_loss', loss,
            on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        stage_buffer = self.outputs[print_str]
        stage_buffer["pred"].append(outputs["pred"].detach().cpu())
        stage_buffer["target"].append(batch["target"].detach().cpu())

        return {"loss": loss}

    def on_test_epoch_start(self):
        """Start every test run with empty buffers.

        Test outputs are kept after ``on_test_epoch_end`` (as concatenated
        tensors), so they must be turned back into lists before the next test
        run appends to them.
        """
        self.outputs["test"] = {"pred": [], "target": []}
        return super().on_test_epoch_start()

    def on_test_epoch_end(self):
        """Compute and log the test metrics; the test outputs stay available."""
        logs = self.on_validation_epoch_end("test")
        return logs

    def on_validation_epoch_end(self, print_str="val"):
        """Aggregate the buffered outputs of a stage, then compute and log metrics.

        Args:
            print_str: Stage name, ``"val"`` or ``"test"``.

        Returns:
            Dictionary of metric name to value, as produced by ``cal_metric``.
        """
        # Concatenate and process prediction and target tensors.
        self.concat_output(print_str)

        # Calculate evaluation metrics.
        metrics = self.cal_metric(print_str)

        # Log the computed metrics on epoch end.
        for metric_name, value in metrics.items():
            self.log(metric_name, value, on_epoch=True, prog_bar=True, logger=True)

        # Validation buffers are reset here; test buffers are kept so the
        # predictions remain accessible after the test run.
        if print_str == "val":
            self.outputs[print_str]["pred"] = []
            self.outputs[print_str]["target"] = []

        return metrics

    def concat_output(self, print_str):
        """Replace a stage's buffered lists with flat prediction/target tensors.

        Args:
            print_str: Stage name, ``"val"`` or ``"test"``.
        """
        stage_buffer = self.outputs[print_str]

        # Concatenate the prediction logits and take the highest-scoring class.
        pred_logit = torch.cat(stage_buffer["pred"], dim=0)
        pred = torch.max(pred_logit, 1)[1]

        # Concatenate the target tensors.
        target = torch.cat(stage_buffer["target"], dim=0)

        # Reshape and store the processed prediction and target tensors.
        stage_buffer["pred"] = pred.reshape(-1)
        stage_buffer["target"] = target.reshape(-1)

    def cal_metric(self, print_str):
        """Compute accuracy, weighted F1 and macro F1, and print the confusion matrix.

        Expects ``concat_output`` to have been called for the stage.

        Args:
            print_str: Stage name, used as the metric-name prefix.

        Returns:
            Dict with keys ``<stage>_acc``, ``<stage>_w_f1`` and ``<stage>_f1``.
        """
        # Extract predictions and targets.
        pred = self.outputs[print_str]["pred"]
        target = self.outputs[print_str]["target"]

        # Calculate accuracy, weighted F1 score, and macro F1 score.
        metric = {
            f"{print_str}_acc": accuracy_score(target, pred),
            f"{print_str}_w_f1": f1_score(target, pred, average="weighted"),
            f"{print_str}_f1": f1_score(target, pred, average="macro"),
        }

        # Display the transposed confusion matrix (rows: predicted, columns: target).
        conf_mtrx = confusion_matrix(target, pred).T
        print()
        print(conf_mtrx)
        print()
        return metric

    def configure_optimizers(self):
        """Create one AdamW optimizer + cosine schedule per parameter group.

        Returns:
            ``([other_optimizer, pretrain_optimizer],
            [other_scheduler, pretrain_scheduler])``. ``on_train_epoch_start``
            and ``model_update`` rely on this order.
        """
        # Calculate the total number of optimization steps.
        # NOTE(review): Lightning's ``estimated_stepping_batches`` may already
        # cover every epoch of the run; confirm the extra ``* epochs`` factor
        # is intended.
        total_step = self.trainer.estimated_stepping_batches \
                       * self.hparams.epochs \
                       // self.hparams.accumulate_num

        # Define parameter groups for the different components (pretrained, others) of the model.
        other_params = [*self.selective_history.parameters()] + \
                       [*self.classifier.parameters()] + \
                       [*self.holistic_history.parameters()]

        pretrain_params = [*self.sentence_encoder.parameters()]

        # The acoustic encoder joins the pretrained group for "wav2vec" and the
        # other group for "rnn"; for any other value it is not optimized here.
        if self.hparams.accustic_feature == "wav2vec":
            pretrain_params += [*self.accustic_encoder.parameters()]
        elif self.hparams.accustic_feature == "rnn":
            other_params += [*self.accustic_encoder.parameters()]

        # Create optimizers and schedulers for the different parameter groups.
        other_optimizer = torch.optim.AdamW(
            other_params,
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

        # Non-pretrained parameters warm up over the first 10% of the steps.
        other_scheduler = get_cosine_schedule_with_warmup(
            optimizer=other_optimizer,
            num_warmup_steps=int(total_step * 0.1),
            num_training_steps=total_step,
            num_cycles=1
        )

        pretrain_optimizer = torch.optim.AdamW(
            pretrain_params,
            lr=self.hparams.pretrain_LR,
            weight_decay=self.hparams.weight_decay
        )

        # Pretrained parameters warm up over the first 30% of the steps.
        pretrain_scheduler = get_cosine_schedule_with_warmup(
            optimizer=pretrain_optimizer,
            num_warmup_steps=int(total_step * 0.3),
            num_training_steps=total_step,
            num_cycles=1
        )

        # Create lists of optimizers and schedulers to return.
        opt_list = [other_optimizer, pretrain_optimizer]
        schedule_list = [other_scheduler, pretrain_scheduler]
        return opt_list, schedule_list

    def model_update(self):
        """Step both optimizers and schedulers, then zero the gradients.

        Returns:
            Tuple ``(other_lr, pretrain_LR)``: the learning rate of the first
            parameter group of each scheduler after the step.
        """
        # Get the optimizers and learning rate schedulers.
        other_opt, pretrain_optimizer = self.optimizers()
        other_scheduler, pretrain_scheduler = self.lr_schedulers()

        # Perform optimization steps.
        other_opt.step()
        pretrain_optimizer.step()

        # Perform learning rate scheduling steps.
        other_scheduler.step()
        pretrain_scheduler.step()

        # Zero out gradients for the next iteration.
        other_opt.zero_grad()
        pretrain_optimizer.zero_grad()

        # ``get_last_lr`` is the supported accessor for the most recently
        # computed rate; ``get_lr`` warns when called outside ``step()``.
        other_lr = other_scheduler.get_last_lr()[0]
        pretrain_LR = pretrain_scheduler.get_last_lr()[0]
        return other_lr, pretrain_LR
