import torch
import torch.nn as nn
from transformers import AutoModel
from config import LOG
import config
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from sentence_transformers import SentenceTransformer, util
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from abc import abstractmethod
from models.sheaf import LocalConcatSheafLearnerVariant
from config import dialogue_map, phq_lexicon
from torch.nn import CrossEntropyLoss
from evaluate import load
from transformers import PretrainedConfig



class RotatingAttention(nn.Module):
    """Additive attention gate that rescales ``query`` by softmax-normalised scores.

    A scalar score ``v(tanh(W_d(query) + W_k(key)))`` is computed per position,
    normalised with a softmax over dimension 1, and multiplied element-wise
    into ``query``.

    Args:
        decoder_hidden_size: Feature size of ``query``; also the width of the
            shared projection space.
        knowledge_size: Feature size of ``key``.
    """

    def __init__(self, decoder_hidden_size, knowledge_size):
        super().__init__()

        # Query and key projections into a common space, followed by the
        # single-output scoring head.
        self.W_d = nn.Linear(decoder_hidden_size, decoder_hidden_size)
        self.W_k = nn.Linear(knowledge_size, decoder_hidden_size)
        self.v = nn.Linear(decoder_hidden_size, 1)

    def forward(self, query, key, value):
        """Return ``query`` scaled by its attention weights against ``key``.

        Args:
            query: Tensor whose last dimension is ``decoder_hidden_size``.
            key: Tensor whose last dimension is ``knowledge_size`` and which
                broadcasts against ``query`` after projection.
            value: Part of the signature but not used in the computation.

        Returns:
            Tensor with the same shape as ``query`` (weights broadcast over the
            last dimension).
        """
        joint_projection = torch.tanh(self.W_d(query) + self.W_k(key))
        # Softmax runs over dimension 1 of the score tensor.
        gate = self.v(joint_projection).softmax(dim=1)
        return gate * query


class MyConfig(PretrainedConfig):
    """Configuration object carrying the fixed sheaf hyper-parameters.

    ``d`` and ``sheaf_out_shape`` are assigned after the base initialiser has
    run, so same-named values passed through ``kwargs`` are overwritten by the
    fixed values set here.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        sheaf_width = 4096
        self.d = 2
        self.sheaf_out_shape = (sheaf_width,)


class DialogueSummarizationModel(PretrainedConfig):
    def __init__(self, **kwargs):
        """Wire up the backbone, auxiliary layers, embedders and PHQ lexicon sentences.

        The backbone is taken from ``config.modelCheckpoint``. All of its
        parameters are frozen, then gradients are re-enabled for the layer at
        index 31 and for the final norm.

        Args:
            **kwargs: Forwarded unchanged to the base-class initialiser.
        """
        super().__init__(**kwargs)
        self.model = config.modelCheckpoint

        # Freeze everything first ...
        for weight in self.model.parameters():
            weight.requires_grad = False
        # ... then unfreeze only the layer at index 31 and the norm layer.
        backbone = self.model.model
        for trainable_part in (backbone.layers[31], backbone.norm):
            for weight in trainable_part.parameters():
                weight.requires_grad = True

        # Output head created here (hidden size 4096 -> vocabulary of 32000).
        llama_hidden_size = 4096
        llama_vocab_size = 32000
        self.lm_head = nn.Linear(llama_hidden_size, llama_vocab_size, bias=False)

        # Sheaf parameters.
        self.d = 2
        self.sheaf_out_shape = (4096,)  # Example sheaf's output shape

        # Scaffold parameters: BiLSTM over concatenated BERT (768) + GloVe (300)
        # utterance embeddings, pooled with a learned attention vector.
        self.d1 = 768 + 300
        self.d2 = 2048  # Hidden size of each BiLSTM direction
        self.bilstm = nn.LSTM(
            input_size=self.d1,
            hidden_size=self.d2,
            batch_first=True,
            bidirectional=True,
        )
        self.attention_weights = nn.Parameter(torch.randn(self.d2 * 2, 1))

        # Text embedders.
        self.glove_embedding = GloVe(name='6B', dim=300)
        self.glove_size = 300
        self.glove_tokenizer = get_tokenizer("basic_english")
        self.bert_embedding = SentenceTransformer('bert-base-nli-mean-tokens')

        # Fusion layers applied in ``forward``.
        self.rotating_attention = RotatingAttention(
            decoder_hidden_size=llama_hidden_size,
            knowledge_size=llama_hidden_size,
        )
        self.transformation_layer = nn.Linear(llama_hidden_size, 2048)

        self.loss_fct = CrossEntropyLoss()
        self.bertscore = load("bertscore")

        # PHQ filtering: one space-joined reference sentence per lexicon entry.
        self.threshold = 0.5
        self.lexi_sent = [" ".join(terms) for terms in phq_lexicon.values()]

        
    def phq_filtering(self, dialogue):
        """Select utterances whose summed BERTScore F1 against the PHQ lexicon exceeds the threshold.

        The dialogue is split on the ``<U>`` separator. Each utterance is
        scored against every sentence in ``self.lexi_sent``; the per-sentence
        F1 values are summed and the utterance is kept when the sum is greater
        than ``self.threshold``.

        Args:
            dialogue: String of utterances separated by ``<U>``.

        Returns:
            Tuple ``(filtered_dialogue, scores)``: the kept utterances (as they
            appear after splitting, not stripped) and, for every utterance, the
            list of F1 values returned by the metric.
        """
        list_of_utterances = dialogue.split("<U>")
        # The metric pairs predictions with references one-to-one, so the
        # utterance must be repeated once per lexicon sentence (previously a
        # hard-coded 10, which only works for a ten-entry lexicon).
        num_references = len(self.lexi_sent)

        scores = []
        filtered_dialogue = []
        for utterance in list_of_utterances:
            utts = [utterance.strip()] * num_references
            results = self.bertscore.compute(predictions=utts, references=self.lexi_sent, lang="en", model_type="distilbert-base-uncased")
            f1_scores = results['f1']
            scores.append(f1_scores)
            # NOTE(review): the F1 values are summed, not averaged, so with
            # several lexicon sentences a 0.5 threshold is probably exceeded by
            # most utterances -- confirm this is intended.
            if sum(f1_scores) > self.threshold:
                filtered_dialogue.append(utterance)

        # str.split always yields at least one element, so the divisor is >= 1.
        knowledge_retention = len(filtered_dialogue)/len(list_of_utterances) * 100
        if LOG: print(f"Knowledge Retention: {knowledge_retention}%")
        return filtered_dialogue, scores


    def forward(self, input_ids, attention_mask, labels, dialogue_IDx):
        """Run the backbone, fuse its last hidden state with the sheaf and scaffold vectors, and score tokens.

        Args:
            input_ids: Token ids forwarded unchanged to ``self.model``.
            attention_mask: Attention mask forwarded unchanged to ``self.model``.
            labels: Target token ids; forwarded to ``self.model`` and, when not
                ``None``, used for the shifted cross-entropy loss computed here.
            dialogue_IDx: Identifier whose string form is the key into
                ``dialogue_map``.

        Returns:
            dict with keys ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"logits"`` and ``"last_hidden_state"``.

        Note:
            The reshapes below use a leading dimension of 1, so this assumes a
            batch size of 1 -- TODO confirm against the training loop.
        """
        # Example Dialogue="Hello, How are you<U>I really dont care<U>about the answer?"
        dialogue_entry = dialogue_map[str(dialogue_IDx)]
        dialogue = dialogue_entry['utterances']
        bert, _glove = self.create_word_embeddings(dialogue, glove_out=False)

        cc_dialogue = dialogue_entry['cc_utterances']
        # NOTE(review): the filtering result is not consumed anywhere below;
        # the call is kept for its logging, but it runs BERTScore on every
        # utterance on every forward pass -- consider dropping or caching it.
        self.phq_filtering(cc_dialogue)
        cc_bert, cc_glove = self.create_word_embeddings(cc_dialogue)

        # Scaffold Model
        context_rich_representation = self.scaffolding(cc_bert, cc_glove)
        if LOG: print("Scaffolded Representations: ",context_rich_representation.shape)

        # Define the graph and sheaf learner
        sheaf_out = self.sheaf_learner(dialogue, bert)
        if LOG: print("Sheaf Output: ", sheaf_out.shape)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Requires the backbone to return hidden states; presumably enabled in
        # its configuration -- verify where the checkpoint is loaded.
        last_hidden_state = outputs.hidden_states[-1]
        if LOG: print("My own outputs shape: ", last_hidden_state.shape)

        # Two attention passes over the same query with the sheaf and scaffold
        # vectors swapped between the key and value slots.
        rotation1 = self.rotating_attention(last_hidden_state, sheaf_out, context_rich_representation)
        if LOG: print("Rotation 1: ", rotation1.shape)
        rotation2 = self.rotating_attention(last_hidden_state, context_rich_representation, sheaf_out)
        if LOG: print("Rotation 2: ", rotation2.shape)

        # Concatenate the two rotations along dimension 1
        concat = torch.cat((rotation1, rotation2), dim=1)
        if LOG: print("Concatenation: ", concat.shape)

        logits = self.lm_head(concat)
        logits = logits.float()
        if LOG: print("My own logits shape: ", logits.shape)
        if LOG: print("Model logits shape: ", outputs.logits.shape)

        # Feature sizes are read from the layers instead of repeating literals.
        hidden_size = self.transformation_layer.in_features
        vocab_size = self.lm_head.out_features

        # NOTE(review): ``view`` only regroups the flat logits buffer into rows
        # of ``hidden_size`` before the linear layer; it is not a per-token
        # projection -- confirm this is the intended way to undo the doubling
        # introduced by the concatenation above.
        reshaped_logits = logits.view(1, -1, hidden_size)
        reshaped_logits = self.transformation_layer(reshaped_logits)
        # The middle dimension is inferred instead of being fixed at 2048, so
        # the result is unchanged for a 2048-position input while other
        # lengths that divide evenly no longer fail.
        logits = reshaped_logits.view(1, -1, vocab_size)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device) # Enable model parallelism
            loss = self.loss_fct(shift_logits, shift_labels)

        return {
            "loss": loss,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs


    # ========= SHEAF =========
    def sheaf_learner(self, dialogue, bert):
        """Apply the sheaf learner to the utterance features over a chain graph.

        One node is created per ``<U>``-separated utterance. Consecutive nodes
        are linked ``i -> i + 1`` and the last node receives a self-loop.

        Args:
            dialogue: String of utterances separated by ``<U>``.
            bert: 2-D tensor of node features; ``bert.shape[1]`` is used as the
                learner's ``hidden_channels``. Presumably one row per
                utterance -- verify against the caller.

        Returns:
            Whatever ``LocalConcatSheafLearnerVariant`` returns for
            ``(bert, edge_index)``.
        """
        # Graph construction
        num_nodes = len(dialogue.split("<U>"))
        edges = [(i, i + 1) for i in range(num_nodes - 1)]
        edges.append((num_nodes - 1, num_nodes - 1))  # self-loop on the last node
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # Sheaf learner: build it once and reuse it. It used to be constructed
        # on every call, which means any weights it holds (presumably it has
        # learnable parameters -- confirm in models.sheaf) were re-initialised
        # on every forward pass. The cache is keyed on the constructor
        # arguments so a change in any of them triggers a rebuild.
        feature_length = bert.shape[1]
        cache_key = (feature_length, self.d, tuple(self.sheaf_out_shape))
        cached = getattr(self, "_sheaf_learner_cache", None)
        if cached is None or cached[0] != cache_key:
            learner = LocalConcatSheafLearnerVariant(
                d=self.d,
                hidden_channels=feature_length,
                out_shape=self.sheaf_out_shape,
            )
            cached = (cache_key, learner)
            self._sheaf_learner_cache = cached

        return cached[1](bert, edge_index)
    




    # ========= SCAFFOLDING =========
    def scaffold_attention(self, contextual_representations):
        attention_scores = torch.matmul(contextual_representations, self.attention_weights) # Calculate attention scores (dot product)
        attention_weights = torch.softmax(attention_scores, dim=0) # Apply softmax to obtain attention weights
        # Compute the context-rich scaffolded representation
        context_rich_representation = torch.sum(attention_weights * contextual_representations, dim=0)
        return context_rich_representation

    def create_word_embeddings(self, dialogue, glove_out=True, bert_out=True):
        """
            dialogue_context: A string containing the dialogue utterance only
            returns: A tensor of shape (num_utts, d1) containing the concatenated GloVe and BERT embeddings of utterances
        """
        
        list_of_utterances = dialogue.split("<U>")
        if bert_out:
            bert_embeddings = self.bert_embedding.encode(list_of_utterances, convert_to_tensor=True) # torch.Size([768])
            feature_length = bert_embeddings.shape[1]  # Feature length of each node
        else: 
            bert_embeddings = None

        if glove_out:
            glove_word_embeddings = torch.zeros(len(list_of_utterances), self.glove_size)  # Blank tensor for GloVe embeddings
            # Iterate through words and create concatenated embeddings
            for i, utterance in enumerate(list_of_utterances):
                u_words = self.glove_tokenizer(utterance)
                glove_embeddings = self.glove_embedding.get_vecs_by_tokens(u_words, lower_case_backup=True) # torch.Size([words, 300])
                glove_embeddings = torch.mean(glove_embeddings, dim=0) # torch.Size([300])
                # Store the concatenated embedding in the tensor
                glove_word_embeddings[i] = glove_embeddings
        else:
            glove_word_embeddings = None

        return bert_embeddings, glove_word_embeddings

    def scaffolding(self, bert, glove):
        """Produce the context-rich vector from per-utterance BERT and GloVe embeddings.

        The two embedding matrices are moved to ``config.DEVICE``, concatenated
        along the feature dimension, passed through the BiLSTM as a single
        batch entry, and pooled with ``scaffold_attention``.

        Args:
            bert: 2-D tensor of BERT embeddings.
            glove: 2-D tensor of GloVe embeddings with the same number of rows.

        Returns:
            The tensor returned by ``scaffold_attention``.
        """
        device = config.DEVICE
        word_embeddings = torch.cat((bert.to(device), glove.to(device)), dim=1)

        # The BiLSTM is batch-first, so add a leading batch dimension of 1.
        bilstm_out, _ = self.bilstm(word_embeddings.unsqueeze(0))

        # Take the single batch entry and join its two direction halves
        # (first ``d2`` features, then the remaining ``d2``).
        sequence_states = bilstm_out[0]
        forward_states = sequence_states[:, :self.d2]
        backward_states = sequence_states[:, self.d2:]
        contextual_representations = torch.cat((forward_states, backward_states), dim=1)

        return self.scaffold_attention(contextual_representations)


