import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

import numpy as np
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
import torch.nn.functional as F


class Config(object):
    """Run configuration built from a dataset root and parsed arguments.

    All paths are formed by plain string concatenation
    (``dataset + args.current_dataset + '/train.txt'`` and so on), so the
    caller is responsible for any separator between ``dataset`` and
    ``args.current_dataset``.

    Args:
        dataset: Path prefix under which the dataset directory and the
            ``saved_dict`` checkpoint directory live.
        args: Object exposing the attributes read below (for example an
            ``argparse.Namespace``).

    Raises:
        OSError: If ``class.txt`` cannot be opened.
        AttributeError: If ``args`` lacks one of the attributes read here.
    """

    def __init__(self, dataset, args):
        self.model_name = args.model_name

        self.current_dataset = args.current_dataset
        data_dir = dataset + self.current_dataset
        self.train_path = data_dir + '/train.txt'
        self.dev_path = data_dir + '/dev.txt'
        self.test_path = data_dir + '/test.txt'
        # One class label per line; the context manager closes the file
        # handle instead of leaving it to the garbage collector.
        with open(data_dir + '/class.txt') as class_file:
            self.class_list = [line.strip() for line in class_file]
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.require_improvement = args.require_improvement
        self.num_classes = len(self.class_list)
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size
        self.pad_size = args.pad_size
        self.learning_rate = args.learning_rate
        self.model_path = args.model_path
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        self.hidden_size1 = args.hidden_size1
        self.hidden_size2 = args.hidden_size2
        self.hidden_size3 = args.hidden_size3
        self.hidden_size4 = args.hidden_size4

        self.dropout = args.dropout
        self.dropout2 = args.dropout2
        # The attribute is read under the misspelled name ``num_cluserts``;
        # that name stays the primary source so existing callers keep
        # working, and the correctly spelled one is accepted as a fallback.
        if hasattr(args, 'num_cluserts'):
            self.num_clusters = args.num_cluserts
        else:
            self.num_clusters = args.num_clusters
        self.st = args.st
        self.reducer_dims = args.reducer_dims
        self.clusters_path = args.clusters_path


class Model(nn.Module):
    """Classifier that pools encoder tokens onto landmarks and runs a GCN.

    ``forward`` performs these steps:

    1. Encode the input with the pretrained model.
    2. Compute min-max normalised Euclidean distances from every token to
       every landmark vector (loaded from ``config.clusters_path``).
    3. Turn distances into Gaussian similarities, normalise them and use
       them to pool token states into one vector per landmark.
    4. Build a landmark-to-landmark score matrix from the sparsified
       similarities and a positional decay matrix; its non-zero entries
       become graph edges.
    5. Run three GCN layers over the landmark graph, concatenate the result
       with the pooled landmark vectors, reduce the landmark axis with a
       linear layer and classify with a three-layer MLP.

    Args:
        config: Object exposing ``model_path``, ``device``, ``dropout``,
            ``dropout2``, ``num_clusters``, ``reducer_dims``,
            ``hidden_size1`` .. ``hidden_size4``, ``num_classes``,
            ``batch_size``, ``pad_size``, ``clusters_path`` and ``st``.
            ``hidden_size1`` must equal the encoder's hidden width because
            it is the input width of the first GCN layer.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.pretrained_model = AutoModel.from_pretrained(config.model_path)
        # Create the tensor on the target device *before* wrapping it.
        # ``nn.Parameter(...).to(device)`` returns a plain, non-leaf tensor
        # whenever the device changes, which is then neither registered on
        # the module nor updated by the optimizer.
        self.d = nn.Parameter(torch.tensor([0.3], device=config.device))
        self.dropout = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout2)  # kept for compatibility; unused here
        self.num_clusters = config.num_clusters

        self.reducer_dims = config.reducer_dims

        # ``gcn_batched`` concatenates the GCN output (hidden_size4) with its
        # input (hidden_size1), so the flattened feature width after the
        # reducer is (hidden_size1 + hidden_size4) * reducer_dims.
        graph_feature_dim = config.hidden_size1 + config.hidden_size4
        self.fc1 = nn.Linear(graph_feature_dim * self.reducer_dims, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, config.num_classes)

        self.landmarks_initialized = False
        self.landmarks = None
        self.device = config.device
        self.batch_size = config.batch_size
        self.gcn1 = GCNConv(config.hidden_size1, config.hidden_size2)
        self.gcn2 = GCNConv(config.hidden_size2, config.hidden_size3)
        self.gcn3 = GCNConv(config.hidden_size3, config.hidden_size4)
        self.padsize = config.pad_size

        self.clusters_ = config.clusters_path
        # Same registration concern as ``self.d`` above.
        self.alpha_for_distanceweight = nn.Parameter(
            torch.tensor([0.5], device=config.device))

        # Maps the landmark axis (num_clusters) down to reducer_dims.
        self.reducer = nn.Linear(self.num_clusters, self.reducer_dims)

        self.st = config.st

    def initialize_landmarks(self):
        """Load the landmark vectors from ``self.clusters_`` as a parameter.

        NOTE(review): this runs lazily on the first ``forward`` call. An
        optimizer constructed before that call will not contain
        ``landmarks``; call this method before building the optimizer if
        the landmarks are meant to be trained.
        """
        initial_landmarks = np.load(self.clusters_)
        self.landmarks = nn.Parameter(
            torch.tensor(initial_landmarks, dtype=torch.float32, device=self.device))
        self.landmarks_initialized = True

    def distance_matrix(self, data, clusters):
        """Return per-token min-max normalised distances to every cluster.

        Args:
            data: Tensor of shape (b, n, d).
            clusters: Tensor of shape (k, d).

        Returns:
            Tensor of shape (b, n, k). For each token the distances are
            rescaled as (dist - min) / (max - min + 1e-6) over the k axis.
        """
        tokens = data.unsqueeze(2).to(self.device)  # (b, n, 1, d)
        centers = clusters.unsqueeze(0).unsqueeze(0).to(self.device)  # (1, 1, k, d)
        distance = torch.sqrt(((tokens - centers) ** 2).sum(dim=-1))  # (b, n, k)

        nearest = distance.min(dim=-1, keepdim=True)[0]
        farthest = distance.max(dim=-1, keepdim=True)[0]

        epsilon = 1e-6  # avoids 0/0 when all distances of a token are equal
        return (distance - nearest) / (farthest - nearest + epsilon)

    def probability_matrix(self, data, pad_mask):
        """Map distances to Gaussian similarities and zero masked tokens.

        Args:
            data: Distance tensor whose leading dimensions match
                ``pad_mask`` (it is called with shape (b, n, k)).
            pad_mask: Tensor of shape (b, n); each token row of the result
                is multiplied by its mask value.

        Returns:
            ``exp(-data**2 / (2 * d**2)) * pad_mask[..., None]``. The values
            are not normalised here; ``forward`` does that.
        """
        similarity = torch.exp(-(data ** 2) / (2 * self.d ** 2))
        return similarity * pad_mask.unsqueeze(-1)

    def compute_distance_weight(self):
        """Return the (pad_size, pad_size) positional decay matrix.

        Entry (i, j) is ``exp(-alpha * |i - j|)`` with the learnable
        ``alpha_for_distanceweight``; the diagonal is forced to zero.
        """
        n = self.padsize
        positions = torch.arange(n, device=self.device)
        gaps = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float()
        decay = torch.exp(-self.alpha_for_distanceweight * gaps)
        diagonal = torch.eye(n, dtype=torch.bool, device=self.device)
        return decay.masked_fill(diagonal, 0.0)

    def gcn_batched(self, data, weights):
        """Run the three GCN layers over one graph per batch element.

        Args:
            data: Node features of shape (B, K, D).
            weights: Per-sample (K, K) matrices stacked as (B, K, K). Only
                the positions of non-zero entries are used (as edges); the
                values themselves are not passed to the GCN layers.

        Returns:
            Tensor of shape (B, K, hidden_size4 + D): the GCN output
            concatenated with the untouched input features.
        """
        B, K, _ = data.shape
        graphs = [
            Data(x=data[i],
                 edge_index=weights[i].nonzero(as_tuple=False).t().contiguous())
            for i in range(B)
        ]
        batched = Batch.from_data_list(graphs)

        out = batched.x
        for layer in (self.gcn1, self.gcn2, self.gcn3):
            out = self.dropout(F.relu(layer(out, batched.edge_index)))

        out = out.view(B, K, -1)
        return torch.concat((out, data), dim=2)

    def sentence_landmarks_projection(self, x, p):
        """Return the batched product ``p^T @ x @ p``."""
        return torch.bmm(torch.bmm(p.transpose(1, 2), x), p)

    def sparsification_norm(self, x, threshold):
        """Keep entries above a per-row quantile and renormalise each row.

        Args:
            x: Tensor with at least three dimensions; rows run along dim 2.
            threshold: Quantile in [0, 1] passed to ``torch.quantile``.

        Returns:
            Tensor shaped like ``x``. Entries not strictly greater than the
            row quantile become zero; the survivors are divided by their
            row sum. Rows whose sum is zero are left as all zeros.
        """
        cutoff = torch.quantile(x, threshold, dim=2, keepdim=True)
        kept = torch.where(x > cutoff, x, torch.zeros_like(x))
        row_sums = kept.sum(dim=2, keepdim=True)
        # Substitute 1 for empty rows so the division below is a no-op there.
        row_sums = torch.where(row_sums != 0, row_sums, torch.ones_like(row_sums))
        return kept / row_sums

    def constructing_graphs_among_landmarks_using_sparsi_matrix(self, data, distance_weight):
        """Score every landmark pair from their token distributions.

        For landmarks i and j the matrix
        ``Q[p, q] = data[b, i, p] * data[b, j, q] * distance_weight[p, q]``
        is formed, and the score is the sum of its column-wise maxima plus
        the sum of its row-wise maxima.

        Args:
            data: Tensor of shape (b, k, n).
            distance_weight: Tensor whose trailing dimensions are (n, n);
                leading singleton dimensions are allowed.

        Returns:
            Tensor of shape (b, k, k).
        """
        pairwise = torch.einsum('bik,bjl->bijkl', data, data)  # (b, k, k, n, n)
        # Broadcasting applies the same weights to every (b, i, j) slice
        # without materialising a repeated copy of them.
        Q = pairwise * distance_weight
        row_maxes = Q.max(dim=-2)[0]
        col_maxes = Q.max(dim=-1)[0]
        return row_maxes.sum(dim=-1) + col_maxes.sum(dim=-1)  # (b, k, k)

    def forward(self, x):
        """Compute class logits for a batch.

        Args:
            x: Indexable batch. ``x[0]`` is passed as the first positional
                argument to the pretrained model and ``x[2]`` as its
                ``attention_mask`` (presumably token ids and a 0/1 padding
                mask of shape (b, pad_size) - confirm against the data
                loader).

        Returns:
            Tensor of shape (b, num_classes) with unnormalised scores.
        """
        # Anomaly detection is a process-wide debugging switch that slows
        # autograd considerably; enable it from the training script when
        # needed rather than on every forward pass.
        context = x[0]
        mask = x[2]
        encoder_out = self.pretrained_model(context, attention_mask=mask).last_hidden_state
        if not self.landmarks_initialized:
            self.initialize_landmarks()

        distance_weight = self.compute_distance_weight()

        distance_matrix = self.distance_matrix(encoder_out, self.landmarks)

        # (b, n, k) token-to-landmark similarities, zero on masked tokens.
        similarity_matrix = self.probability_matrix(distance_matrix, mask)
        # Normalise over landmarks, then over tokens.
        probability_matrix_row = similarity_matrix / (
            similarity_matrix.sum(dim=2, keepdim=True) + 1e-6)
        probability_matrix_col = probability_matrix_row / (
            probability_matrix_row.sum(dim=1, keepdim=True) + 1e-6)

        # b*k*d: one pooled vector per landmark.
        out_k_d = torch.bmm(probability_matrix_col.transpose(1, 2), encoder_out)

        sparsi_matrix = self.sparsification_norm(similarity_matrix.transpose(1, 2), self.st)
        lan_Graph = self.constructing_graphs_among_landmarks_using_sparsi_matrix(
            sparsi_matrix, distance_weight.unsqueeze(0))

        features = self.gcn_batched(out_k_d, lan_Graph)  # (b, k, h1 + h4)
        features = self.reducer(features.permute(0, 2, 1))  # (b, h1 + h4, reducer_dims)
        features = features.reshape(features.size(0), -1)

        features = self.dropout(F.gelu(self.fc1(features)))
        features = self.dropout(F.gelu(self.fc2(features)))
        # ``fc3`` is the 100 -> num_classes output layer. The code previously
        # referenced the undefined ``fc11`` and ``dropout3`` here, which
        # raised AttributeError; the logits are returned without dropout.
        return self.fc3(features)
