import torch
import torch.nn as nn
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from sentence_transformers import SentenceTransformer, util
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from abc import abstractmethod



# Compute device: the second GPU when at least two are present, the first GPU on
# single-GPU machines (where "cuda:1" would be an invalid device ordinal), else CPU.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device("cuda:1")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

class ScaffoldBilstmAttentionClassifier(nn.Module):
    """Encode a dialogue into a single attention-pooled context vector.

    The dialogue string is split on ``"<U>"`` into utterances. Each utterance is
    embedded as the concatenation of its mean GloVe word vector and its SBERT
    sentence vector. The resulting sequence is run through a BiLSTM, and a learned
    dot-product attention (``attention_weights`` acts as the query vector) pools the
    BiLSTM outputs into one vector of size ``2 * d2``.
    """

    def __init__(self):
        super().__init__()

        self.glove_embedding = GloVe(name='6B', dim=300)  # GloVe word vectors
        self.glove_tokenizer = get_tokenizer("basic_english")  # tokenizer for GloVe lookups
        self.bert_embedding = SentenceTransformer('bert-base-nli-mean-tokens')  # SBERT encoder

        self.glove_dim = 300  # must match the ``dim`` passed to GloVe above
        self.bert_dim = 768  # sentence-vector size of the SBERT model above
        self.d1 = self.bert_dim + self.glove_dim  # per-utterance embedding size (BERT + GloVe)
        self.d2 = 256  # hidden size of each BiLSTM direction

        # Each BiLSTM time step is one utterance embedding of size d1, so input_size
        # must be d1 (not 2 * d1) for forward() to run.
        self.bilstm = nn.LSTM(input_size=self.d1, hidden_size=self.d2,
                              batch_first=True, bidirectional=True)
        # Learned attention query, dotted with every BiLSTM output (size 2 * d2).
        self.attention_weights = nn.Parameter(torch.randn(self.d2 * 2, 1))
        # NOTE(review): not used by forward(); presumably intended for a downstream
        # classification head — confirm before removing.
        self.linear = nn.Linear(self.d2 * 2, 1024)

    def create_word_embeddings(self, dialogue_context):
        """Embed every utterance of a dialogue.

        Args:
            dialogue_context: Dialogue text whose utterances are separated by ``"<U>"``.

        Returns:
            Tensor of shape ``(num_utts, d1)``. Each row is the utterance's mean GloVe
            vector (``glove_dim`` values) followed by its SBERT vector. The tensor is
            created on the device and with the dtype of the model parameters, so it
            can be fed directly to ``self.bilstm`` wherever the model lives.
        """
        device = self.attention_weights.device
        dtype = self.attention_weights.dtype

        rows = []
        for utterance in dialogue_context.split("<U>"):
            tokens = self.glove_tokenizer(utterance)
            if tokens:
                # (num_tokens, glove_dim) -> (glove_dim,)
                glove_vec = self.glove_embedding.get_vecs_by_tokens(
                    tokens, lower_case_backup=True).mean(dim=0)
            else:
                # Empty utterances (e.g. from a trailing or doubled "<U>") have no
                # tokens: there is nothing to look up, and a mean over zero rows would
                # be NaN, so use a zero GloVe vector instead.
                glove_vec = torch.zeros(self.glove_dim)
            bert_vec = self.bert_embedding.encode(utterance, convert_to_tensor=True)  # (bert_dim,)
            rows.append(torch.cat((glove_vec.to(device), bert_vec.to(device)), dim=0))

        # split() always yields at least one element, so rows is never empty.
        return torch.stack(rows).to(dtype=dtype)

    def contextual_embeddings(self, word_embeddings):
        """Run the utterance embeddings through the BiLSTM.

        Args:
            word_embeddings: Tensor of shape ``(num_utts, d1)``.

        Returns:
            Tensor of shape ``(num_utts, 2 * d2)``: forward and backward hidden
            states concatenated per time step (the layout ``nn.LSTM`` already emits
            for a bidirectional layer).
        """
        bilstm_out, _ = self.bilstm(word_embeddings.unsqueeze(0))  # add batch dim
        return bilstm_out[0]  # drop the batch dim

    def attention(self, contextual_representations):
        """Pool utterance representations with learned dot-product attention.

        Args:
            contextual_representations: Tensor of shape ``(num_utts, 2 * d2)``.

        Returns:
            Tensor of shape ``(2 * d2,)``: attention-weighted sum over utterances.
        """
        scores = torch.matmul(contextual_representations, self.attention_weights)  # (num_utts, 1)
        weights = torch.softmax(scores, dim=0)  # normalise across utterances
        return torch.sum(weights * contextual_representations, dim=0)

    def forward(self, dialogue_context):
        """Encode a dialogue into its context-rich scaffolded representation.

        Args:
            dialogue_context: Dialogue text whose utterances are separated by ``"<U>"``.

        Returns:
            Tensor of shape ``(2 * d2,)``.
        """
        word_embeddings = self.create_word_embeddings(dialogue_context)
        contextual_representations = self.contextual_embeddings(word_embeddings)
        return self.attention(contextual_representations)


if __name__ == "__main__":
    # Smoke run: build the model (loads GloVe and SBERT weights) and encode one
    # sample dialogue. Guarded so importing this module has no such side effects.
    print("starting model")
    model = ScaffoldBilstmAttentionClassifier()
    dialogue_context = "Hello, How are you<U>I really dont care<U>about the answer?"
    with torch.no_grad():  # inference only; no autograd graph needed
        context_rich_representation = model(dialogue_context)
    print(context_rich_representation.shape)




# # 


# class SheafLearner(nn.Module):
#     """Base model that learns a sheaf from the features and the graph structure."""
#     def __init__(self):
#         super(SheafLearner, self).__init__()
#         self.L = None

#     @abstractmethod
#     def forward(self, x, edge_index):
#         raise NotImplementedError()

#     def set_L(self, weights):
#         self.L = weights.clone().detach()

# class AttentionSheafLearner(SheafLearner):
#     def __init__(self, in_channels, d):
#         super(AttentionSheafLearner, self).__init__()
#         self.d = d
#         self.linear1 = torch.nn.Linear(in_channels*2, d**2, bias=False)
    
#     def forward(self, x, edge_index):
#         """
#         x: node features
#         edge_index: graph structure
#         returns: aggregated features of shape [num_nodes, d]
#         """
#         row, col = edge_index
#         x_row = torch.index_select(x, dim=0, index=row)
#         x_col = torch.index_select(x, dim=0, index=col)
#         maps = self.linear1(torch.cat([x_row, x_col], dim=1)).view(-1, self.d, self.d)

#         id = torch.eye(self.d, device=edge_index.device, dtype=maps.dtype).unsqueeze(0)
#         maps = id - torch.softmax(maps, dim=-1)
#         aggregated_features = torch.matmul(maps, x.unsqueeze(-1))
#         aggregated_features = aggregated_features.squeeze(-1)
#         return aggregated_features



# # Example conversation
# conversation = [
#     "Hello, how are you?",
#     "I'm good, thanks! How about you?",
#     "I'm doing well, too.",
#     "Glad to hear that."
# ]

# # Step 1: Convert utterances to feature vectors
# # vectorizer = TfidfVectorizer()
# # X_tfidf = vectorizer.fit_transform(conversation).toarray()
# X_tfidf = torch.randn(4,10)  # random input tensor (node, features)


# # Step 2: Create a simple sequential graph
# num_utterances = len(conversation)
# edges = [(i, i + 1) for i in range(num_utterances - 1)]
# edges.append((num_utterances - 1, num_utterances - 1, )) # self-loop
# # Step 3: Prepare x and edge_index for the AttentionSheafLearner
# x = torch.tensor(X_tfidf, dtype=torch.float)
# edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

# in_channels = 10  # Number of features in the input
# d = 10  # Dimension of the sheaf || THIS IS SUPPOSED TO BE THE SAME AS THE INPUT FEATURES

# print("X Shape: ", x)  
# print("EI Shape: ",edge_index.shape)
# attention_learner = AttentionSheafLearner(in_channels, d)
# feat = attention_learner(x, edge_index)
# print(feat.shape)
# print(feat)




# # from your_module import AttentionSheafLearner  # Import the AttentionSheafLearner from your module

# # # Define the input parameters
# # in_channels = 4  # Number of input channels or features
# # d = 5  # Dimension

# # # Create a random input tensor (node features)
# # x = torch.randn(7, in_channels)  # Example: 5 nodes with 4 input features (5 x 4)

# # print("x: ", x.shape)

# # # Create a random edge index (assuming a simple undirected graph)
# # edge_index = torch.tensor([[0, 1, 2, 3, 4, 2, 4], [1, 0, 3, 2, 4, 2, 4]], dtype=torch.long)

# # print("edge_index: ", edge_index.shape)

# # # Initialize the AttentionSheafLearner
# # attention_learner = AttentionSheafLearner(in_channels, d)

# # # Pass the input and edge index through the AttentionSheafLearner
# # feat = attention_learner(x, edge_index)
# # print("aggregated_features: ", feat.shape)


# # print(x)
# # print("= = "*5)
# # print(feat)








# # import torch
# # from sklearn.feature_extraction.text import TfidfVectorizer
# # import numpy as np