import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.utils.data as data
import torch.optim as optim
from Utils.utils import *

class GCN_layer(nn.Module):
    """Single graph-convolution layer: propagate over ``A``, then project.

    Computes ``fc(A @ X)`` where ``A`` is the (fixed) graph matrix handed in
    at construction time and ``fc`` is a learnable affine map.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        A: Graph matrix multiplied with the node features via ``torch.spmm``.
    """

    def __init__(self, in_features, out_features, A):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        # Stored before ``fc`` so that, if ``A`` is an ``nn.Parameter``, the
        # parameter registration order stays the same as it has always been.
        self.A = A
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, X):
        """Return the projected, neighbourhood-aggregated node features."""
        propagated = torch.spmm(self.A, X)
        return self.fc(propagated)
    
class Expert(nn.Module):
    """Two-layer feed-forward expert (Linear -> ReLU -> Linear).

    Args:
        dims: Sequence of layer widths. Only the first three entries are
            read: input width, hidden width and output width.
    """

    def __init__(self, dims):
        super().__init__()

        in_dim = dims[0]
        hidden_dim = dims[1]
        out_dim = dims[2]
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the expert MLP to ``x`` and return the result."""
        transformed = self.mlp(x)
        return transformed

class Indexing_network(nn.Module):
    """Mixture-of-experts network scoring inputs against class and phrase embeddings.

    ``M`` experts transform the input embedding; two independent softmax
    gates (``g1`` for classes, ``g2`` for phrases) mix the expert outputs.
    The class-side mixture is scored against class embeddings refined by a
    two-layer GCN over ``A``; the phrase-side mixture is scored against the
    raw phrase embeddings.

    Args:
        class_emb: Class embedding matrix, one 768-d row per class.
        phrase_emb: Phrase embedding matrix, one 768-d row per phrase.
        A: Graph matrix over the classes, consumed by the GCN layers.
        M: Number of experts (and therefore the width of each gate).
    """

    def __init__(self, class_emb, phrase_emb, A, M):
        super(Indexing_network, self).__init__()

        # Embedding width shared by inputs, experts and class/phrase embeddings.
        emb_dim = 768

        # Kept as frozen parameters (not buffers) so that state_dict keys and
        # parameter registration stay compatible with existing checkpoints.
        self.class_emb = nn.Parameter(class_emb, requires_grad=False)
        self.phrase_emb = nn.Parameter(phrase_emb, requires_grad=False)
        self.A = nn.Parameter(A, requires_grad=False)

        self.M = M
        self.GNN = nn.Sequential(
            GCN_layer(emb_dim, emb_dim, self.A),
            nn.ReLU(),
            GCN_layer(emb_dim, emb_dim, self.A),
        )
        self.experts = nn.ModuleList(
            [Expert([emb_dim, emb_dim, emb_dim, emb_dim]) for _ in range(self.M)]
        )
        # Each gate must emit one weight per expert. This was previously
        # hard-coded to 3, which only worked when M == 3.
        self.g1 = nn.Sequential(nn.Linear(emb_dim, self.M), nn.Softmax(dim=1))
        self.g2 = nn.Sequential(nn.Linear(emb_dim, self.M), nn.Softmax(dim=1))

    def forward(self, batch_X):
        """Score a batch of embeddings against classes and phrases.

        Args:
            batch_X: 2-D tensor of input embeddings, shape ``(batch, 768)``.

        Returns:
            Tuple ``(class_output, phrase_output, expert_output1, expert_output2)``:
            the class scores ``(batch, n_classes)``, the phrase scores
            ``(batch, n_phrases)``, and the two gated expert mixtures
            ``(batch, 768)`` that produced them.
        """
        class_emb = self.GNN(self.class_emb)

        # (batch, 768, M): one slice per expert along the last axis.
        expert_outputs = torch.stack(
            [expert(batch_X) for expert in self.experts], dim=-1
        )

        # (batch, 1, M): broadcasting over the feature axis replaces the
        # explicit ``repeat(1, 768, 1)`` copy without changing the result.
        g1_dist = self.g1(batch_X).unsqueeze(1)
        g2_dist = self.g2(batch_X).unsqueeze(1)

        expert_output1 = (expert_outputs * g1_dist).sum(2)
        expert_output2 = (expert_outputs * g2_dist).sum(2)

        class_output = torch.mm(expert_output1, class_emb.T)
        phrase_output = torch.mm(expert_output2, self.phrase_emb.T)
        return class_output, phrase_output, expert_output1, expert_output2

class TSI_module(nn.Module):
    """Fine-tuning wrapper that enriches embeddings with a pre-trained indexing network.

    A pre-trained ``Indexing_network`` (``self.CLF``) is loaded from disk and
    its GCN is frozen. Its two gated expert mixtures are fused by an adapter
    and added to the input embedding as a gated residual.

    Args:
        class_emb: Class embedding matrix passed to ``Indexing_network``.
        phrase_emb: Phrase embedding matrix passed to ``Indexing_network``.
        A: Graph matrix passed to ``Indexing_network``.
        M: Number of experts in the indexing network.
        indexing_network_path: Path of the saved ``Indexing_network`` state dict.
        lambda_IL: Weight of the indexing (classification) loss in the total loss.
    """

    def __init__(self, class_emb, phrase_emb, A, M, indexing_network_path, lambda_IL):
        super(TSI_module, self).__init__()
        self.CLF = Indexing_network(class_emb, phrase_emb, A, M)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location="cpu" lets a checkpoint saved on GPU be restored on a
        # CPU-only host; load_state_dict copies the values into the existing
        # parameters, so the module's own device placement is unaffected.
        state_dict = torch.load(indexing_network_path, map_location="cpu")
        self.CLF.load_state_dict(state_dict)

        # The GCN that refines class embeddings stays fixed during fine-tuning.
        for param in self.CLF.GNN.parameters():
            param.requires_grad = False

        self.adapter = nn.Sequential(nn.Linear(768 * 2, int(768 * 1.5)), nn.ReLU(), nn.Linear(int(768 * 1.5), 768))
        self.fusion_module = nn.Sequential(nn.Linear(768, 1), nn.Sigmoid())
        self.alpha = nn.Parameter(torch.ones(1) * 1.)
        self.CLF_loss = nn.CrossEntropyLoss()
        self.lambda_IL = lambda_IL

    # We follow the implementation of MultipleNegativesRankingLoss in SentenceTransformer framework
    def fine_tune_loss(self, reps):
        """In-batch ranking loss over anchor/candidate embeddings.

        Args:
            reps: Sequence of embedding tensors. ``reps[0]`` holds the
                anchors; the remaining tensors are concatenated into the
                candidate pool, so candidate ``i`` is the positive for
                anchor ``i`` and every other candidate acts as a negative.

        Returns:
            Scalar cross-entropy loss over the anchor-vs-candidate scores.
        """
        embeddings_a = reps[0]
        embeddings_b = torch.cat(reps[1:])

        # dot_score comes from the project's Utils.utils star import;
        # presumably a pairwise dot-product matrix -- confirm against Utils.utils.
        scores = dot_score(embeddings_a, embeddings_b)
        # The matching candidate for anchor i sits at column i.
        labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        return self.CLF_loss(scores, labels)

    def forward(self, x_emb):
        """Return the enriched embedding together with the indexing scores.

        Args:
            x_emb: 2-D tensor of input embeddings, shape ``(batch, 768)``.

        Returns:
            Tuple ``(weighted_x_emb, x_class_output, x_phrase_output)``: the
            input plus its gated adapter residual, and the class / phrase
            scores produced by the indexing network.
        """
        x_class_output, x_phrase_output, x_expert_output1, x_expert_output2 = self.CLF(x_emb)
        x_expert = self.adapter(torch.cat([x_expert_output1, x_expert_output2], -1))
        # Residual update scaled by the learnable alpha and a per-sample
        # sigmoid gate computed from the input itself.
        weighted_x_emb = x_emb + x_expert * self.alpha * self.fusion_module(x_emb)

        return weighted_x_emb, x_class_output, x_phrase_output

    def get_fine_tune_loss(self, q_emb, d_emb, neg_d_emb, batch_Y, batch_Y2):
        """Compute the combined ranking and indexing loss for one batch.

        Args:
            q_emb: Query embeddings.
            d_emb: Positive document embeddings, aligned row-wise with ``q_emb``.
            neg_d_emb: Negative document embeddings.
            batch_Y: Class targets, applied to both query and document class scores.
            batch_Y2: Phrase targets, applied to both query and document phrase scores.

        Returns:
            Scalar tensor: ranking loss plus ``lambda_IL`` times the summed
            class/phrase cross-entropy losses.
        """
        # Call the module itself (not .forward) so registered hooks still run.
        weighted_q_emb, q_class_output, q_phrase_output = self(q_emb)
        weighted_d_emb, d_class_output, d_phrase_output = self(d_emb)
        weighted_neg_d_emb, _, _ = self(neg_d_emb)

        total_fine_tune_loss = self.fine_tune_loss([weighted_q_emb, weighted_d_emb, weighted_neg_d_emb])

        total_IL_loss = (
            self.CLF_loss(q_class_output, batch_Y)
            + self.CLF_loss(q_phrase_output, batch_Y2)
            + self.CLF_loss(d_class_output, batch_Y)
            + self.CLF_loss(d_phrase_output, batch_Y2)
        )

        return total_fine_tune_loss + total_IL_loss * self.lambda_IL