import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from abc import abstractmethod
from torch import nn
import config


class SheafLearner(nn.Module):
    """Common parent for modules that learn a sheaf from node features and graph structure.

    Subclasses are expected to provide ``forward``. The base class itself only
    keeps ``self.L``, a detached snapshot of whatever tensor is handed to
    ``set_L``.
    """

    def __init__(self):
        super().__init__()
        # Stays None until set_L() stores a snapshot here.
        self.L = None

    @abstractmethod
    def forward(self, x, edge_index):
        """Subclass hook; the base implementation always raises NotImplementedError."""
        raise NotImplementedError()

    def set_L(self, weights):
        """Store a copy of ``weights`` that is cut off from the autograd graph."""
        snapshot = weights.clone()
        self.L = snapshot.detach()

class LocalConcatSheafLearnerVariant(SheafLearner):
    """Learns a single pooled sheaf map from the endpoint features of every edge.

    For each edge the two endpoint feature rows are concatenated, regrouped into
    ``d`` chunks of ``hidden_channels`` values that are summed together, passed
    through a bias-free linear layer followed by ``tanh``, reshaped to
    ``out_shape`` and finally averaged over all edges.

    Args:
        d: Number of ``hidden_channels``-wide chunks summed per row in the
            reshape step of ``forward``.
        hidden_channels: Input width of the linear layer.
        out_shape: Shape of the returned map; must have one or two dimensions.
    """

    def __init__(self, d: int, hidden_channels: int, out_shape: Tuple[int, ...]):
        super().__init__()
        assert len(out_shape) in [1, 2], "out_shape must have 1 or 2 dimensions"
        self.out_shape = out_shape
        self.d = d
        self.hidden_channels = hidden_channels
        # One flat output vector per row; forward() reshapes it to out_shape.
        self.linear1 = torch.nn.Linear(hidden_channels, int(np.prod(out_shape)), bias=False)
        self.linear1.to(config.DEVICE)
        self.act = torch.tanh

    def forward(self, x, edge_index):
        """Compute the edge-averaged map.

        Args:
            x: Node feature matrix, indexed along dim 0 by the entries of
                ``edge_index``.
            edge_index: Index tensor that unpacks into two 1-D index tensors
                (the two endpoints of every edge).

        Returns:
            Tensor of shape ``out_shape``: the mean over dim 0 of the per-row maps.
        """
        # Inputs are moved to the same device the linear layer was placed on.
        x, edge_index = x.to(config.DEVICE), edge_index.to(config.DEVICE)
        row, col = edge_index
        x_row = torch.index_select(x, dim=0, index=row)
        x_col = torch.index_select(x, dim=0, index=col)
        # Endpoint features side by side: one row per edge.
        x_cat = torch.cat([x_row, x_col], dim=-1)
        # With d == 2 and x's feature width equal to hidden_channels, this splits
        # every edge row back into its two endpoint halves and adds them.
        # NOTE(review): for any other d the reshape regroups values across edge
        # rows -- confirm that d != 2 is really intended by callers.
        x_cat = x_cat.reshape(-1, self.d, self.hidden_channels).sum(dim=1)

        maps = self.act(self.linear1(x_cat))
        # Same result as the separate 1-D / 2-D view branches it replaces.
        maps = maps.view(-1, *self.out_shape)

        # Average over dim 0 (one entry per row), leaving a tensor of shape out_shape.
        return torch.mean(maps, dim=0)

# # Example node features and edge index
# num_nodes = 4  # Number of nodes in the graph
# feature_length = 66  # Feature length of each node

# x = torch.randn(num_nodes, feature_length)
# edges = [(i, i + 1) for i in range(num_nodes - 1)]
# edges.append((num_nodes - 1, num_nodes - 1, ))  # self-loop
# edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

# # Parameters for LocalConcatSheafLearnerVariant
# d = 2
# hidden_channels = feature_length
# out_shape = (100,)  # Example output shape

# learner = LocalConcatSheafLearnerVariant(d, hidden_channels, out_shape)
# output = learner(x, edge_index)





# 

# class AttentionSheafLearner(SheafLearner):
#     def __init__(self, in_channels, d):
#         super(AttentionSheafLearner, self).__init__()
#         self.d = d
#         self.linear1 = torch.nn.Linear(in_channels*2, d**2, bias=False)
    
#     def forward(self, x, edge_index):
#         """
#         x: node features
#         edge_index: graph structure
#         returns: aggregated features of shape [num_nodes, d]
#         """
#         row, col = edge_index
#         x_row = torch.index_select(x, dim=0, index=row)
#         x_col = torch.index_select(x, dim=0, index=col)
#         maps = self.linear1(torch.cat([x_row, x_col], dim=1)).view(-1, self.d, self.d)

#         id = torch.eye(self.d, device=edge_index.device, dtype=maps.dtype).unsqueeze(0)
#         maps = id - torch.softmax(maps, dim=-1)
#         aggregated_features = torch.matmul(maps, x.unsqueeze(-1))
#         aggregated_features = aggregated_features.squeeze(-1)
#         return aggregated_features



# # Example conversation
# conversation = [
#     "Hello, how are you?",
#     "I'm good, thanks! How about you?",
#     "I'm doing well, too.",
#     "Glad to hear that."
# ]

# X_tfidf = torch.randn(4,10)  # random input tensor (node, features)


# # Step 2: Create a simple sequential graph
# num_utterances = len(conversation)
# edges = [(i, i + 1) for i in range(num_utterances - 1)]
# edges.append((num_utterances - 1, num_utterances - 1, )) # self-loop
# # Step 3: Prepare x and edge_index for the AttentionSheafLearner
# x = torch.tensor(X_tfidf, dtype=torch.float)
# edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

# in_channels = 10  # Number of features in the input
# d = 10  # Dimension of the sheaf; must equal in_channels (the number of input features)

# print("X Shape: ", x.shape)  
# print("EI Shape: ",edge_index.shape)
# attention_learner = AttentionSheafLearner(in_channels, d)
# feat = attention_learner(x, edge_index)
# print(feat.shape)
# print(feat)



