import os
import sys
import math
import random
from typing import Tuple, List
import wandb
from tqdm import tqdm
from sklearn.metrics import f1_score, confusion_matrix
import numpy as np
from multiprocessing import Pool
import torch
from torch import nn
import torch.nn.functional as F 
import torch_geometric.transforms as T
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
from torch_geometric.data import Batch
from torch_geometric.nn.inits import glorot, ones
from torch_geometric.utils import to_dense_adj
from torch_scatter import scatter
import evaluate
from preprocess import Dataset
from GPT_GNN.config import pos_size, tokenizer, roberta, in_dim, edge_dim, gen_step_size, text_embed_size
from GPT_GNN.conv.GAT import GAT


class SmallGraphError(Exception):
    """Raised when a graph lacks the nodes or edges needed for a forward pass."""


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a leading prefix of node embeddings.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 10000):
        """Precompute the encoding table.

        Arguments:
            d_model: embedding width (even or odd).
            dropout: dropout probability applied to the output of ``forward``.
            max_len: number of positions for which encodings are precomputed.
        """
        super().__init__()
        self.max_len = max_len
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine columns, so the
        # frequency vector is truncated to fit (no-op when d_model is even).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, t: int, node_embeds: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to the first ``t`` rows only, then apply dropout.

        Arguments:
            t: number of leading rows that receive a positional encoding.
            node_embeds: Tensor, shape ``[num_nodes, embedding_dim]``

        Returns:
            Tensor with the same shape as ``node_embeds``.
        """
        expanded = node_embeds.unsqueeze(1)
        x, y = expanded[:t, :, :], expanded[t:, :, :]
        s = x.shape[0]
        x = x + self.pe[:s]
        x = torch.cat([x, y], dim=0)
        return self.dropout(x).squeeze(1)


class Attention(nn.Module):
    """Scores candidate embeddings against a context vector.

    Each score is the feature-wise sum of ``tanh(W(h * embeds))`` divided by
    ``sqrt(hidden_dim)``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.w = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, h, embeds):
        """Return one scalar score per row of the broadcast product ``h * embeds``."""
        scale = math.sqrt(self.hidden_dim)
        activated = self.tanh(self.w(h * embeds))
        return activated.sum(dim=-1) / scale
        

class Pointer(nn.Module):
    """Pointer head: summarises the already-seen nodes and scores each candidate against it."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.encoder = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.attention = Attention(hidden_dim)
        # Registered but not applied in forward.
        self.dropout = nn.Dropout(0.2)

    def forward(self, past, next):
        """Score every row of ``next`` against the summed encoding of ``past``."""
        # Collapse all past nodes into a single (1, hidden_dim) context vector.
        context = self.encoder(past).sum(dim=0).unsqueeze(0)
        return self.attention(context, next)


class GNN(nn.Module):
    def __init__(self):
        """Build the encoder: projections, GAT stack, pointer head and loss hyper-parameters.

        The sizes (``in_dim``, ``pos_size``, ``edge_dim``, ``text_embed_size``), the generation
        step size and the shared ``roberta`` encoder all come from ``GPT_GNN.config``.
        """
        super(GNN, self).__init__()
        # generate_edge_ordering keeps only edges whose first four attribute values all have a
        # magnitude below this threshold.
        self.edge_filter_threshold = 1.0  # 0.3
        # Feature / position vectors of the artificial global node. They are plain tensors (not
        # registered buffers); add_global_node moves them to the input's device on each call.
        self.start_node = torch.zeros(in_dim).float() + 0.0001
        self.start_pos = torch.zeros(pos_size).float() + 0.0001
        self.num_attn_heads = 16
        self.hidden_dim =  text_embed_size # self.num_attn_heads ** 2
        self.dropout_rate = 0.2
        # Stride used when walking over the sampled node window.
        self.step_size = gen_step_size
        self.dropout = nn.Dropout(self.dropout_rate)
        self.in_proj = nn.Linear(in_features=in_dim, out_features=self.hidden_dim)
        self.dense = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
        self.pos_encoder = PositionalEncoding(self.hidden_dim)
        self.conv = GAT(in_channels=self.hidden_dim, 
                        hidden_channels=self.hidden_dim,  # int(self.hidden_dim/self.num_attn_heads), 
                        out_channels=self.hidden_dim, 
                        num_layers=2, 
                        v2=True, 
                        dropout=self.dropout_rate, 
                        act="LeakyReLU", 
                        act_first=False,
                        jk="lstm",
                        edge_dim=edge_dim,
                        heads=self.num_attn_heads,
                        add_self_loops=True)
        self.next_node_pointer = Pointer(self.hidden_dim)
        self.roberta = roberta
        self.hidden_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Produces one segment-boundary logit per node.
        self.segment_proj = nn.Linear(self.hidden_dim, 1)
        # w_1..w_4 are the bias-free projections applied in forward_once before tanh.
        self.w_1 = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) # nn.Parameter(torch.Tensor(self.hidden_dim))
        self.w_2 = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) # nn.Parameter(torch.Tensor(self.hidden_dim))
        self.w_3 = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.w_4 = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.loss_alpha = 0.01
        self.loss_gamma = 2.
        # NOTE(review): num_heads equals the embedding size, i.e. one feature per head, whereas
        # MLMGNN uses 12 heads — confirm this is intended rather than self.num_attn_heads.
        self.attn = nn.MultiheadAttention(self.hidden_dim, num_heads=self.hidden_dim, dropout=0.0, batch_first=True)
        self.cosine = nn.CosineSimilarity(dim=1)
        self.tanh = nn.Tanh()
        # self.reset_parameters()
    
    def reset_parameters(self):
        """Placeholder hook: performs no custom parameter initialisation."""
        return None

    def prepare_node_features(self, node_feats, mask):
        token_ids = node_feats
        token_embeds = self.embed_tokens(token_ids, mask)
        return token_ids, token_embeds

    def get_cls_embeds(self, token_ids, mask):
        token_embeddings = self.roberta(token_ids.long(), attention_mask=mask)
        token_embeddings = token_embeddings.pooler_output.squeeze(1)
        return token_embeddings

    def embed_tokens(self, token_ids, mask):
        token_embeddings = self.roberta(token_ids.long(), attention_mask=mask)
        token_embeddings = token_embeddings.last_hidden_state[:, 1, :].squeeze(1)
        return token_embeddings
    
    def add_global_node(self, node_feature, pos, comm, block, segment, edge_index, edge_label, edge_attr):
        device = node_feature.device
        # define global node and add it to current nodes
        global_node = self.start_node.unsqueeze(0).to(device)
        node_features = torch.cat([global_node, node_feature], dim=0)
        global_pos = self.start_pos.unsqueeze(0).to(device)
        pos = torch.cat([global_pos, pos], dim=0)
        comm = torch.cat([torch.Tensor([0]).to(device), comm+1.], dim=0).long()
        block = torch.cat([torch.Tensor([0]).to(device), block+1.], dim=0).long()
        segment = torch.cat([torch.Tensor([0]).to(device), segment], dim=0).long()
        edge_indexes = [ei+1 for ei in edge_index]
        edge_attrs = edge_attr  # + 1
        edge_labels = edge_label
        return node_features, pos, comm, block, segment, edge_indexes, edge_labels, edge_attrs

    def attend(self, tx, ty, n_out):
        return (tx * ty) / math.sqrt(n_out)

    def node_loss_ce(self, preds, true):
        loss = F.cross_entropy(preds, true)
        return loss
    
    def node_loss_emd(self, preds, true):
        loss = torch.mean(torch.square(torch.cumsum(true, dim=-1) - torch.cumsum(torch.softmax(preds, dim=-1), dim=-1)), dim=-1)
        return loss
    
    def node_loss_nll(self, preds, true):
        return -torch.log(torch.softmax(preds, dim=0)[true]) * 1./(1.+math.log(preds.shape[0]))
    
    def node_loss_triplet(self, prev, next):
        n = next.shape[0]-1
        # if only one node is left, that one is guaranteed to the next node, so the triplet loss doesn't make sense
        if n == 0: return torch.tensor(0.0, requires_grad=True).to(prev.device)
        anchor = prev.repeat(n, 1)
        pos = next[0, :].unsqueeze(0).repeat(n, 1)
        neg = next[1:, :]
        return F.triplet_margin_loss(anchor, pos, neg, margin=1.0)
    
    def node_loss_contrastive(self, anchor, pos, all):
        tau = 0.1
        numerator = torch.exp(F.cosine_similarity(anchor, pos, dim=0) / tau)
        denominator = torch.sum(torch.exp(F.cosine_similarity(anchor, all, dim=1) / tau), dim=0)
        loss = - torch.log(numerator / denominator)
        return loss.unsqueeze(0)

    def edge_loss(self, preds, true):
        return F.binary_cross_entropy(torch.sigmoid(preds), true).unsqueeze(0)
    
    def block_loss(self, preds, true):
        return F.binary_cross_entropy(torch.sigmoid(preds), true).unsqueeze(0)
    
    def block_loss_bce(self, preds, true):
        return F.binary_cross_entropy(torch.sigmoid(preds), true).unsqueeze(0)
    
    @staticmethod
    def block_loss_contrastive(numerator):
        temp = 0.1
        num_segs = numerator.size(0)
        if num_segs < 1: return torch.tensor([0.0], requires_grad=True).to(numerator.device)
        l = torch.ones(num_segs)
        l = l*num_segs
        l = l.long().to(numerator.device)
        nmr = torch.repeat_interleave(numerator, l, dim=0)
        denominator = numerator.repeat(num_segs, 1)
        # (torch.min(nmr), torch.max(nmr))
        denominator = F.cosine_similarity(nmr, denominator)
        denominator = torch.exp(denominator/temp).reshape(num_segs, num_segs).sum(dim=1)
        one = torch.exp(torch.ones(1).to(numerator.device)/temp)
        fraction = one / torch.sum(denominator-one, dim=0)
        return torch.clamp(torch.nan_to_num(-torch.log(fraction), nan=0.0), min=0.0)
    
    def generate_node_ordering(self, node_embeds, num_nodes):
        """Build padded "previous"/"next" node views and masks for each step ``i`` in ``1..num_nodes-1``.

        ``node_embeds`` is indexed with a (row, column) pair of index tensors, so it presumably
        has shape ``(num_nodes - 1, num_nodes, ...)`` — TODO confirm against the caller.

        Returns:
            ``(prev_nodes, prev_node_mask, next_nodes, next_node_mask, past_node, cur_node)``.
            Masks hold ``0.0`` at valid slots and ``-100.0`` at padded slots, and every mask row
            is repeated ``num_attn_heads`` times along dim 0.
        """
        d = node_embeds.device

        # Row for step i lists nodes 0..i-1, right-padded with index 0 to length num_nodes-1.
        prev_node_indices = [[j for j in range(0, i)] + [0 for _ in range(i, num_nodes-1)] for i in range(1, num_nodes)]
        prev_node_indices = torch.Tensor(prev_node_indices).long().to(d)
        prev_nodes = node_embeds[torch.arange(prev_node_indices.size(0)).unsqueeze(1), prev_node_indices]

        # Additive mask matching the padding layout above.
        prev_node_mask = [[0.0 for _ in range(0, i)] + [-100.0 for _ in range(i, num_nodes-1)] for i in range(1, num_nodes)]
        prev_node_mask = torch.Tensor(prev_node_mask).to(d)
        prev_node_mask = torch.repeat_interleave(prev_node_mask, self.num_attn_heads, dim=0)

        # Row for step i lists nodes i+1..num_nodes-1, left-padded with index 0 to length num_nodes-2.
        next_node_indices = [[0 for _ in range(0, i-1)] + [j for j in range(i+1, num_nodes)] for i in range(1, num_nodes)]
        next_node_indices = torch.Tensor(next_node_indices).long().to(d)
        next_nodes =  node_embeds[torch.arange(next_node_indices.size(0)).unsqueeze(1), next_node_indices]

        next_node_mask = [[-100.0 for _ in range(0, i-1)] + [0.0 for _ in range(i+1, num_nodes)] for i in range(1, num_nodes)]
        next_node_mask = torch.Tensor(next_node_mask).to(d)
        next_node_mask = torch.repeat_interleave(next_node_mask, self.num_attn_heads, dim=0)

        # Entry (k, k) of node_embeds for k in 0..num_nodes-2.
        past_node_indices = [i for i in range(0, num_nodes-1)]
        past_node_indices = torch.Tensor(past_node_indices).long().to(d)
        past_node = node_embeds[torch.arange(past_node_indices.size(0)), past_node_indices]

        # Entry (k, k + 1) of node_embeds for k in 0..num_nodes-2.
        cur_node_indices = [i for i in range(1, num_nodes)]
        cur_node_indices = torch.Tensor(cur_node_indices).long().to(d)
        cur_node = node_embeds[torch.arange(cur_node_indices.size(0)), cur_node_indices]

        return prev_nodes, prev_node_mask.unsqueeze(2), next_nodes, next_node_mask.unsqueeze(1), past_node.unsqueeze(1), cur_node.unsqueeze(1)
    
    def select_edge_indexes(self, edge_indexes, i, idx):
        x = edge_indexes[:, idx].squeeze()
        return x

    def select_edge_attrs(self, edge_attrs, i, idx):
        x = edge_attrs[idx, :].squeeze()
        return x

    def generate_adj_vectors(self, num_nodes, edge_indexes, rand_start, rand_end):
        # try:
        device = edge_indexes.device
        steps = torch.arange(rand_start, rand_end, step=self.step_size).unsqueeze(1).to(device)
        shp = steps.shape[0]
        from_i = edge_indexes[0, :].unsqueeze(0).repeat(shp, 1)  == steps
        from_i_adj = edge_indexes[1, :].unsqueeze(0).repeat(shp, 1) < steps
        from_i_adj = torch.logical_and(from_i, from_i_adj)
        to_i = edge_indexes[1, :].unsqueeze(0).repeat(shp, 1) == steps
        to_i_adj = edge_indexes[0, :].unsqueeze(0).repeat(shp, 1) < steps
        to_i_adj = torch.logical_and(to_i, to_i_adj)
        i_adj = torch.logical_or(from_i_adj, to_i_adj)
        if len(i_adj.size()) == 0: raise SmallGraphError
        step = 0
        true_next_edge = []
        for i in range(rand_start, rand_end, self.step_size):
            if step >= i_adj.shape[0] or i_adj[step].shape[-1] == 0: raise SmallGraphError
            # include the 'next node' in the output vector, but filter it during loss calculation
            adj_vec = torch.zeros(i+1).to(device) 
            adj_vec[edge_indexes[:, i_adj[step]][0]] = 1.
            adj_vec[edge_indexes[:, i_adj[step]][1]] = 1.
            # no self loops
            adj_vec[i] = 0.
            true_next_edge.append(adj_vec)
            step += 1
        del i_adj, to_i, from_i, to_i_adj, from_i_adj
        return true_next_edge
    
    def generate_edge_ordering(self, e_indexes, e_labels, e_attrs, num_nodes, rand_start, rand_end):
        device = e_indexes.device
        # filter down to short edges
        f = self.edge_filter_threshold
        short_edges = torch.logical_and(torch.logical_and((torch.abs(e_attrs[:, 0]) < f), (torch.abs(e_attrs[:, 1]) < f)), torch.logical_and((torch.abs(e_attrs[:, 2]) < f), (torch.abs(e_attrs[:, 3]) < f)))
        edge_indexes = e_indexes[:, short_edges]
        edge_labels = e_labels[short_edges]
        edge_attrs = e_attrs[short_edges, :]
        # make the edge matrices undirected (symmetric)
        edge_indexes = torch.cat([edge_indexes, edge_indexes[torch.Tensor([1, 0]).long().to(device)]], dim=1)
        edge_labels = torch.cat([edge_labels, edge_labels], dim=0)
        edge_attrs = torch.cat([edge_attrs, edge_attrs], dim=0)
        # generate edge ordering
        device = edge_indexes.device
        steps = torch.arange(rand_start, rand_end, step=self.step_size).unsqueeze(1).to(device)
        shp = steps.shape[0]
        edges_forward = edge_indexes[0, :].unsqueeze(0).repeat(shp, 1) < steps
        edges_backward = edge_indexes[1, :].unsqueeze(0).repeat(shp, 1) < steps
        edges = torch.logical_and(edges_forward, edges_backward)
        edge_index = [edge_indexes[:, e] for e in edges]
        edge_attr = [edge_attrs[e, :] for e in edges]
        edge_label = [edge_labels[e] for e in edges]
        del edges, edges_forward, edges_backward
        return edge_index, edge_label, edge_attr
    
    def shift_indexes_backward(self, edge_indexes):
        num_nodes, _, num_edges = edge_indexes.shape
        idx = torch.arange(0, num_nodes).unsqueeze(1).unsqueeze(1)
        idx = idx.repeat(1, 2, num_edges).to(edge_indexes.device)
        return edge_indexes - idx
    
    def comm_loss_attn(self, edge_indexes, edge_attention_weights, comms):
        if torch.isnan(edge_attention_weights).any(): print('attention weights')
        if torch.isnan(comms).any(): print('comms')
        mean_weights = edge_attention_weights.mean(dim=-1)
        comms_i = comms[edge_indexes[0]]
        comms_j = comms[edge_indexes[1]]
        diff_comms = (comms_i != comms_j).float()
        penalty = mean_weights - diff_comms
        penalty = torch.pow(penalty, 2)
        return penalty.mean()

    def comm_loss_bce(self, edge_indexes, preds, comms):
        comms_i = comms[edge_indexes[0]]
        comms_j = comms[edge_indexes[1]]
        same_comms = (comms_i == comms_j).float()
        preds_mean = preds.mean(dim=-1)  # mean across attention heads
        return F.binary_cross_entropy(torch.sigmoid(preds_mean), same_comms).unsqueeze(0)
    
    @staticmethod
    def contrastive_loss(preds, trues, index):
        """Contrastive loss between two views after sum-pooling their rows by ``index``.

        Rows of ``preds`` and ``trues`` that share an ``index`` value are summed into one group
        vector. Each pooled ``preds`` row is then compared with its matching pooled ``trues`` row
        (numerator) and with five randomly drawn pooled ``trues`` rows (denominator), using cosine
        similarity with temperature 0.1.

        Returns a ``[0.0]`` tensor when ``index`` has fewer than two entries or when fewer than
        six groups exist; otherwise the clamped (non-negative, NaN-free) loss tensor.
        """
        temp = 0.1
        neg_sample_size = 5
        if index.shape[0] < 2: return torch.tensor([0.0], requires_grad=True).to(preds.device)
        # Sum-pool both views per group id.
        preds = scatter(preds, index, dim=0, reduce="sum")
        trues = scatter(trues, index, dim=0, reduce="sum")
        num_nodes = preds.shape[0]
        if num_nodes < neg_sample_size+1: return torch.tensor([0.0], requires_grad=True).to(preds.device)
        # A single random permutation is shared by all rows; row r reads a window of
        # neg_sample_size consecutive entries starting at offset r (wrapping around).
        # NOTE(review): nothing here excludes a group's own index from its negatives — confirm.
        rand_indices = torch.randperm(num_nodes).unsqueeze(0).repeat(num_nodes, 1)
        selection = torch.arange(neg_sample_size).unsqueeze(0).repeat(num_nodes, 1)
        shift = torch.arange(num_nodes).unsqueeze(1)
        selection = selection + shift  # [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], ...]
        selection[selection >= num_nodes] = selection[selection >= num_nodes] - num_nodes  # out of bound indices will loop over
        rand_indices = torch.gather(rand_indices, 1, selection)
        truess = trues[rand_indices].reshape(neg_sample_size*num_nodes, -1)
        predss = preds.repeat_interleave((torch.ones(num_nodes).to(preds.device)*neg_sample_size).long(), dim=0)
        dist = F.cosine_similarity(predss, truess)/temp
        # NOTE(review): dist is laid out group-major (each group's negatives are adjacent) but is
        # reshaped as (neg_sample_size, num_nodes, -1); verify that the values summed below
        # really belong to the same anchor.
        ex = torch.exp(dist).reshape(neg_sample_size, num_nodes, -1)
        denominator = ex.sum(dim=0)
        numerator = torch.exp(F.cosine_similarity(preds, trues)/temp)
        # NOTE(review): assuming 2-D pooled inputs, numerator is (num_nodes,) and denominator is
        # (num_nodes, 1), so the ratio broadcasts to (num_nodes, num_nodes) — confirm intended.
        return torch.clamp(torch.nan_to_num(-torch.log(numerator / denominator), nan=0.0), min=0.0)
    
    @staticmethod
    def mulco_loss(h_text, h_gnn, blocks):
        h_text_hat = h_text.detach()
        h_gnn_hat = h_gnn.detach()
        loss1 = GNN.contrastive_loss(h_text, h_gnn_hat, blocks)
        loss2 = GNN.contrastive_loss(h_gnn, h_text_hat, blocks)
        return (loss1 + loss2).mean()
    
    @staticmethod
    def get_segment_embeds(node_embeds, pred_segment_boundaries, gold_segment_boundaries):
        """Split the nodes into predicted segments and sum-pool the embeddings of each segment.

        A node starts a segment when the sigmoid of its boundary logit exceeds 0.5; node 0 is
        always forced to start one.

        Returns:
            ``(segment_embeds, segment_index, segment_sizes, segment_bs, block_loss)`` where
            ``segment_index`` assigns every node to a segment, ``segment_bs`` is the 0/1 boundary
            vector and ``block_loss`` is the boundary loss against the gold boundaries.
        """
        gold = gold_segment_boundaries.unsqueeze(1).float()
        block_loss = NodeClassifier.block_loss(pred_segment_boundaries, gold)
        boundary_probs = torch.sigmoid(pred_segment_boundaries.squeeze(1))
        segment_bs = (boundary_probs > 0.5).long()
        segment_bs[0] = 1
        num_segments = segment_bs.sum()
        starts = segment_bs.nonzero().squeeze()
        # A single detected segment squeezes down to a 0-dim tensor; restore one dimension.
        if len(starts.size()) == 0:
            starts = starts.unsqueeze(0)
        # Each segment ends where the next begins; the last one ends at the node count.
        node_count = torch.Tensor([node_embeds.shape[0]]).to(node_embeds.device)
        ends = torch.cat([starts[1:], node_count], dim=0)
        segment_sizes = (ends - starts).long()
        segment_ids = torch.arange(num_segments).to(num_segments.device)
        segment_index = torch.repeat_interleave(segment_ids, segment_sizes, dim=0)
        assert node_embeds.shape[0] == segment_index.shape[0]
        segment_embeds = scatter(node_embeds, segment_index, dim=0, reduce="sum")
        return segment_embeds, segment_index, segment_sizes, segment_bs, block_loss

    def forward_once(self, t, x, e_indexes, e_labels, e_attrs, pos, comms, batch_apply=False):
        """Encode the first ``t`` nodes and, unless ``batch_apply``, score next-node / next-edge candidates.

        Arguments:
            t: number of leading nodes passed through the GAT stack.
            x: node features; rows ``t:`` act as the not-yet-generated candidates.
            e_indexes, e_labels, e_attrs: edges among the first ``t`` nodes.
            pos, comms: per-node positions and community ids (sliced to ``:t``).
            batch_apply: when true, skip the pointer / edge scoring and only return embeddings.

        Returns:
            ``(h, h_hat, next_node_selection, true_idx, next_edge_weights, edge_attention_weights)``.
            ``next_node_selection`` and ``next_edge_weights`` are ``None`` when ``batch_apply`` is true.
        """
        h, (ei, edge_attention_weights) = self.conv(x[:t], pos[:t], comms[:t], e_indexes, edge_type=e_labels.long(), edge_attr=e_attrs, return_attention_weights=True)
        h = self.hidden_proj(h)
        # Self-attention is computed on the projected embeddings, before the residual is added.
        h_hat, _ = self.attn(h, h, h, need_weights=False)
        # residual connection
        h = h + x[:t]
        # Projections feeding the heads below: the rows taken depend on batch_apply.
        h_t_minus_1 = self.tanh(self.w_1(h[0:-1, :] if batch_apply else h[t-1, :].unsqueeze(0)))
        h_t_to_T = self.tanh(self.w_2(x[1:, :] if batch_apply else x[t:, :]))
        h_0_to_t_minus_1 = self.tanh(self.w_3(h[0:-1, :] if batch_apply else h[:t, :]))
        h_t = self.tanh(self.w_4(x[1:, :] if batch_apply else x[t, :].unsqueeze(0)))
        if batch_apply:
            next_node_selection = None
            next_edge_weights = None
            # Note: in this branch true_idx is not moved to the input's device.
            true_idx = torch.Tensor([0])

        else:
            num_nodes = h_t_to_T.shape[0]
            device = h_t_to_T.device
            # NOTE(review): perm is never used; its only effect is drawing from the random generator.
            perm = torch.randperm(num_nodes)
            # The target is always candidate 0, i.e. the first row of x[t:].
            true_idx = torch.Tensor([0]) 
            true_idx = true_idx.to(device).float()
            next_node_selection = self.next_node_pointer(h_0_to_t_minus_1, h_t_to_T)
            i = h_t 
            j = h_0_to_t_minus_1 
            # Scaled dot product between node t and each of the first t nodes: one logit per potential edge.
            next_edge_weights = (i * j).sum(dim=1).unsqueeze(0)/math.sqrt(self.hidden_dim)
            next_edge_weights = self.dropout(next_edge_weights).squeeze(0)
        return h, h_hat, next_node_selection, true_idx, next_edge_weights, edge_attention_weights

    # Input: a graph data object carrying node token ids (x), edges (edge_index) and their
    # labels, attributes, positions, masks, community / block / segment ids.
    # Outputs 1-5: node, edge, community, segment and block loss scalars (in that order).
    # Output 6: the multi-view contrastive (mulco) losses, one entry per generation step.
    def forward(self, data):
        """Compute the generative pre-training losses over a random window of nodes.

        A window of up to 20 consecutive nodes is sampled. For every step in the window the
        graph built so far is encoded and the model is asked to predict the next node, the
        next node's edges to earlier nodes and the segment boundaries.

        Arguments:
            data: graph object exposing ``x``, ``edge_index``, ``edge_label``, ``edge_attr``,
                ``pos``, ``mask``, ``comm``, ``block``, ``segment`` and ``y``.

        Returns:
            ``(node_loss, edge_loss, comm_loss, segment_loss, block_loss, mulco_losses)``: the
            first five are sums over all steps, the last holds one value per step.

        Raises:
            SmallGraphError: if the graph has no edge attributes or fewer than two nodes, or
                when a helper finds no usable edges.
        """
        x, edge_index, edge_label, edge_attr, pos, mask, comms, blocks, segments, y = data.x, data.edge_index, data.edge_label, data.edge_attr, data.pos, data.mask, data.comm, data.block, data.segment, data.y
        if edge_attr is None or len(edge_attr) == 0:
            raise SmallGraphError
        _, node_features = self.prepare_node_features(x, mask)
        num_nodes = node_features.shape[0]
        # A single-node graph leaves no room for a window; report it the same way as the other
        # too-small cases instead of letting random.randint fail on an empty range.
        if num_nodes < 2:
            raise SmallGraphError
        rand_size = random.randint(1, min(20, num_nodes - 1))
        rand_start = random.randint(0, num_nodes - rand_size - 1)
        rand_end = rand_start + rand_size
        edge_indexes, edge_labels, edge_attrs = self.generate_edge_ordering(edge_index, edge_label, edge_attr, num_nodes, rand_start, rand_end)
        true_next_edges = self.generate_adj_vectors(num_nodes, edge_index, rand_start, rand_end)
        node_features, pos, comms, blocks, segments, e_indexes, e_labels, e_attrs = self.add_global_node(node_features, pos, comms, blocks, segments, edge_indexes, edge_labels, edge_attrs)
        node_losses, edge_losses, comm_losses, block_losses, segment_losses, mulco_losses = [], [], [], [], [], []
        # Indices are shifted by one because the global node now occupies position 0.
        for step, idx in enumerate(range(rand_start + 1, rand_end + 1, self.step_size)):
            # Hide the positions of nodes that have not been generated yet.
            pos_masked = torch.cat([pos[:idx, :], torch.zeros_like(pos[idx:, :]).to(pos.device)], dim=0)
            node_embeds, h_hat, next_node_selection, true_node_selection, next_edge_weights, edge_attention_weights = self.forward_once(idx, node_features, e_indexes[step], e_labels[step], e_attrs[step], pos_masked, comms, batch_apply=False)
            node_loss = self.node_loss_nll(next_node_selection, true_node_selection.squeeze().long())
            edge_loss = self.edge_loss(next_edge_weights, true_next_edges[step][:idx])
            # Segment boundaries are predicted for the real nodes only (the global node is skipped).
            segment_preds = self.segment_proj(node_embeds[1:idx, :])
            block_loss = self.block_loss_bce(segment_preds, segments[1:idx].unsqueeze(1).float())
            segment_loss = torch.tensor([0.0], requires_grad=True).to(node_embeds.device)
            mulco_loss = GNN.mulco_loss(node_features[:idx], h_hat[:idx], blocks[:idx])
            node_losses.append(node_loss.unsqueeze(0))
            edge_losses.append(edge_loss)
            # No community loss is computed here; a zero keeps the return signature stable.
            comm_losses.append(torch.Tensor([0]).to(x.device))
            block_losses.append(block_loss.unsqueeze(0))
            segment_losses.append(segment_loss.unsqueeze(0))
            mulco_losses.append(mulco_loss.unsqueeze(0))
        node_losses = torch.cat(node_losses, dim=0)
        edge_losses = torch.cat(edge_losses, dim=0)
        comm_losses = torch.cat(comm_losses, dim=0).nan_to_num(nan=0.0)
        block_losses = torch.cat(block_losses, dim=0)
        segment_losses = torch.cat(segment_losses, dim=0)
        mulco_losses = torch.cat(mulco_losses, dim=0)
        # The block loss is aggregated over every step, like the other losses, rather than
        # reporting only the value of the final step.
        return node_losses.sum(), edge_losses.sum(), comm_losses.sum(), segment_losses.sum(), block_losses.sum(), mulco_losses


class MLMGNN(nn.Module):
    def __init__(self):
        """Build the GAT stack, the self-attention layer and the segment-boundary projection.

        ``text_embed_size``, ``edge_dim`` and the shared ``roberta`` encoder come from
        ``GPT_GNN.config``.
        """
        super().__init__()
        self.hidden_dim = text_embed_size
        # Used by the self-attention layer only; the GAT stack below runs with dropout=0.0.
        self.dropout_rate = 0.2
        self.num_attention_heads = 16
        self.conv = GAT(in_channels=self.hidden_dim, 
                        hidden_channels=self.hidden_dim,  # int(self.hidden_dim/self.num_attn_heads), 
                        out_channels=self.hidden_dim, 
                        num_layers=2, 
                        v2=True, 
                        dropout=0.0, 
                        act="LeakyReLU", 
                        act_first=False,
                        jk="lstm",
                        edge_dim=edge_dim,
                        heads=self.num_attention_heads,
                        add_self_loops=True)
        self.attn = nn.MultiheadAttention(self.hidden_dim, num_heads=12, dropout=self.dropout_rate, batch_first=True)
        self.cosine = nn.CosineSimilarity(dim=1)
        self.roberta = roberta
        # Produces one segment-boundary logit per node.
        self.segment_proj = nn.Linear(self.hidden_dim, 1)
        self.tanh = nn.Tanh()

    def prepare_node_features(self, node_feats, mask):
        token_ids = node_feats
        token_embeds = self.embed_tokens(token_ids, mask)
        return token_ids, token_embeds
    
    def embed_tokens(self, token_ids, mask):
        token_embeddings = self.roberta(token_ids.long(), attention_mask=mask)
        token_embeddings = token_embeddings.pooler_output   #.last_hidden_state[:, 1, :].squeeze(1)
        return token_embeddings
    
    def block_loss_bce(self, preds, true):
        return F.binary_cross_entropy(torch.sigmoid(preds), true).unsqueeze(0)
    
    # t, x, e_indexes, e_labels, e_attrs, pos, comms, batch_apply=False
    def forward_once(self, idx, node_features, edge_index, edge_label, edge_attr, pos, comms, batch_apply=True):
        h, (ei, edge_attention_weights) = self.conv(node_features, pos, comms, edge_index, edge_type=edge_label.long(), edge_attr=edge_attr, return_attention_weights=True)
        h_hat, _ = self.attn(h, h, h, need_weights=False)
        # residual connection
        h = h + node_features
        return h, h_hat, None, None, None, None

    def forward(self, data):
        """Compute the block, segment and multi-view contrastive losses for one graph.

        Arguments:
            data: graph object exposing ``x``, ``edge_index``, ``edge_label``, ``edge_attr``,
                ``pos``, ``mask``, ``comm``, ``block`` and ``segment``.

        Returns:
            ``(None, None, None, block_loss, segment_loss, mulco_loss)``.
        """
        x, edge_index, edge_label, edge_attr, pos, mask, comms, blocks, segments = data.x, data.edge_index, data.edge_label, data.edge_attr, data.pos, data.mask, data.comm, data.block, data.segment
        _, node_features = self.prepare_node_features(x, mask)
        h, h_hat, _, _, _, _ = self.forward_once(-1, node_features, edge_index, edge_label, edge_attr, pos, comms)
        # The boundary logits are projected once; get_segment_embeds both derives the segments
        # from them and returns the boundary loss, so no separate loss call is needed.
        segment_logits = self.segment_proj(h)
        segment_embeds, _, _, _, block_loss = GNN.get_segment_embeds(h, segment_logits, segments)
        segment_loss = GNN.block_loss_contrastive(segment_embeds)
        mulco_loss = GNN.mulco_loss(node_features, h_hat, blocks)
        return None, None, None, block_loss, segment_loss, mulco_loss
    

class SeqClassificationHead(nn.Module):
    """Classification head: dropout, dense layer, tanh, dropout, output projection."""

    def __init__(self, hidden_size, dropout_rate, num_classes):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.out_proj = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map features of width ``hidden_size`` to ``num_classes`` logits."""
        hidden = torch.tanh(self.dense(self.dropout(x)))
        return self.out_proj(self.dropout(hidden))
    

class GraphClassificationHead(nn.Module):
    """Classification head built on a two-layer unidirectional LSTM.

    The LSTM narrows the features to ``hidden_size // 4`` before tanh, dropout and the
    final projection to ``num_classes`` logits.
    """

    def __init__(self, hidden_size, dropout_rate, num_classes):
        super().__init__()
        reduced_size = hidden_size // 4
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=reduced_size,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.out_proj = nn.Linear(reduced_size, num_classes)

    def forward(self, x):
        """Return per-position logits for the input sequence ``x``."""
        sequence_out, _state = self.rnn(x)
        activated = self.dropout(torch.tanh(sequence_out))
        return self.out_proj(activated)


class NodeClassifier(nn.Module):
    def __init__(self, num_classes, gnn, gnn_type):
        """Wrap a pretrained graph encoder with text-level and graph-level classification heads.

        Arguments:
            num_classes: number of node classes to predict.
            gnn: encoder module; ``forward`` uses its ``prepare_node_features``, ``forward_once``
                and ``segment_proj``.
            gnn_type: stored as-is; not used by the methods of this class shown here.
        """
        super(NodeClassifier, self).__init__()
        self.num_classes = num_classes
        print('num classes are', num_classes)
        self.num_segment_classes = 4
        self.gnn = gnn
        self.gnn_type = gnn_type
        self.hidden_dim = text_embed_size # self.gnn.conv.out_channels if self.gnn is not None else 0
        self.dropout_rate = 0.5
        self.dropout = nn.Dropout(self.dropout_rate)
        self.softmax = nn.Softmax(dim=1)
        self.beam_projection = nn.Linear(self.num_classes, in_dim)
        self.emission_projection = nn.Linear(self.hidden_dim, self.num_classes)
        self.output_projection = nn.Linear(self.num_classes, self.num_classes, bias=False)
        # Head applied to the raw text embeddings.
        self.text_classifier = SeqClassificationHead(self.hidden_dim, self.dropout_rate, self.num_classes)
        # Head applied to the graph-encoded node embeddings.
        self.graph_classifier = GraphClassificationHead(self.hidden_dim, self.dropout_rate, self.num_classes)
        self.iterations = 3
        self.beam_size = 2
        self.tanh = nn.Tanh()
        
    def init_transition_params(self):
        i_counts = nn.Parameter(torch.zeros(self.num_classes, self.num_classes).float(), requires_grad=False)
        j_given_i_counts = nn.Parameter(torch.zeros(self.num_classes, self.num_classes).float(), requires_grad=False)
        transition_probs = nn.Parameter(torch.zeros(self.num_classes, self.num_classes).float(), requires_grad=False)
        return i_counts, j_given_i_counts, transition_probs

    def update_transition_probs(self, y):
        device = y.device 
        ys = torch.cat([torch.zeros(1).to(device), y], dim=0)
        for i in range(self.num_classes):
            for j in range(self.num_classes):
                i_match = (ys==i)[:-1]
                j_match = (ys==j)[1:]
                ij_match = torch.logical_and(i_match, j_match)
                self.i_counts[i, j] = self.i_counts[i, j] + i_match.sum().float()
                self.j_given_i_counts[i, j] = self.j_given_i_counts[i, j] + ij_match.float().sum()
        self.transition_probs = (self.j_given_i_counts + 1.) / (self.i_counts + 1.)
        return self.transition_probs
    
    @staticmethod
    def sync_transition_probs(i, ji, num_classes):
        ji_agg = ji.reshape(-1, num_classes, num_classes).sum(dim=0)
        i_agg = i.reshape(-1, num_classes, num_classes).sum(dim=0)
        return (ji_agg + 1.) / (i_agg + 1.)

    # Adapted from https://github.com/napsternxg/pytorch-practice/blob/master/Viterbi%20decoding%20and%20CRF.ipynb
    @staticmethod
    def viterbi_decoding_torch(emissions, transitions, num_classes):
        device = emissions.device
        transitions = torch.softmax(transitions, dim=1)
        scores = torch.zeros(emissions.size(1)).to(device)
        back_pointers = torch.zeros(emissions.size()).int().to(device)
        # max_scoress = torch.zeros(emissions.size()).float().to(device)
        scores = scores + emissions[0]
        # Generate most likely scores and paths for each step in sequence
        for i in range(1, emissions.size(0)):
            scores_with_transitions = scores.unsqueeze(1).expand_as(transitions) + transitions
            max_score, back_pointers[i] = torch.max(scores_with_transitions, 0)
            # max_scoress[i] = max_scores
            scores = emissions[i] + max_score
        # Generate the most likely path
        viterbi = [torch.argmax(scores).unsqueeze(0)]
        # final_scores = [torch.max(scores).unsqueeze(0)]
        # back_pointers = back_pointers.numpy()
        # idx = 0
        bps = back_pointers[1:].flip([0])
        # ms = max_scoress[1:].flip([0])
        for bp in bps:
            viterbi.append(bp[viterbi[-1].item()].unsqueeze(0))
            # final_scores.append(ms[idx][viterbi[-1].item()].unsqueeze(0))
            # idx += 1
        # final_scores.append(max_scoress[viterbi[-1].item()].unsqueeze(0))
        # final_scores.reverse()
        viterbi_score = torch.max(scores)
        viterbi.reverse()
        viterbi = torch.cat(viterbi, dim=0)
        viterbi = nn.Parameter(F.one_hot(viterbi, num_classes).float(), requires_grad=True)
        return viterbi_score, viterbi      

    def forward(self, data):
        """Classify every node from both the text view and the graph view.

        Returns an 11-tuple:
        ``(text_logits, graph_logits, node_features, h_hat, y, node_loss, mulco_loss,
        block_loss, segment_loss, segment_sizes, segment_bs)``.
        """
        (x, edge_index, edge_label, edge_attr, pos, mask, comms, y, blocks, segments) = (
            data.x, data.edge_index, data.edge_label, data.edge_attr, data.pos,
            data.mask, data.comm, data.y, data.block, data.segment,
        )
        _, node_features = self.gnn.prepare_node_features(x, mask)
        num_nodes = node_features.shape[0]
        # Encode the full graph in one pass; no global node is added here.
        encoded = self.gnn.forward_once(
            num_nodes, node_features, edge_index, edge_label, edge_attr, pos, comms, batch_apply=True
        )
        node_embeds, h_hat = encoded[0], encoded[1]
        boundary_logits = self.gnn.segment_proj(node_embeds)
        segment_embeds, _segment_index, segment_sizes, segment_bs, block_loss = GNN.get_segment_embeds(
            node_embeds, boundary_logits, segments
        )
        graph_preds = self.graph_classifier(node_embeds)
        node_loss = NodeClassifier.node_loss(graph_preds, y, segment_sizes, segment_loss=False)
        segment_loss = GNN.block_loss_contrastive(segment_embeds)
        mulco_loss = GNN.mulco_loss(node_features, h_hat, blocks)
        text_preds = self.text_classifier(node_features)
        return (
            text_preds, graph_preds, node_features, h_hat, y, node_loss,
            mulco_loss, block_loss, segment_loss, segment_sizes, segment_bs,
        )
    
    @staticmethod
    def node_loss(preds, trues, pred_segment_sizes, segment_loss=False):
        if segment_loss:
            trues = torch.div(trues + 1.0, torch.Tensor([2.]).to(trues.device), rounding_mode="floor").long()
            preds = torch.repeat_interleave(preds, pred_segment_sizes, dim=0)
        return F.cross_entropy(preds, trues)
    
    @staticmethod
    def block_loss(preds, trues):
        return F.binary_cross_entropy(torch.sigmoid(preds), trues)


class LinkClassificationHead(nn.Module):
    """Pairwise link scorer over node embeddings.

    Two separate linear projections of the node embeddings are multiplied
    together, so entry ``(i, j)`` of the score matrix is the dot product of
    ``dense1(node_embeds)[i]`` and ``dense2(node_embeds)[j]``.
    """

    def __init__(self, hidden_size, dropout_rate):
        """Create the two ``hidden_size -> hidden_size`` projections.

        Args:
            hidden_size: Width of the incoming node embeddings.
            dropout_rate: Probability for the ``dropout`` module.
        """
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        # Registered on the module but not applied in ``forward``.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, node_embeds, edge_index, num_nodes):
        """Score all node pairs and build the matching dense adjacency target.

        Args:
            node_embeds: Node embedding matrix fed to both projections.
            edge_index: Edge list handed to ``to_dense_adj``.
            num_nodes: Passed as ``max_num_nodes`` to ``to_dense_adj``.

        Returns:
            ``(preds, trues)``: the raw pairwise scores and the dense
            adjacency with its leading dimension squeezed away.
        """
        left = self.dense1(node_embeds)
        right = self.dense2(node_embeds)
        pair_scores = torch.matmul(left, right.transpose(0, 1))

        dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)
        return pair_scores, dense_adj.squeeze(0)


class LinkDetector(nn.Module):
    """Link-prediction model: a GNN encoder followed by a pairwise link head.

    ``forward`` encodes the graph with ``self.gnn``, scores every node pair
    with ``LinkClassificationHead`` and returns the scores, the dense
    adjacency target and the binary cross-entropy between them.
    """

    def __init__(self, gnn, gnn_type):
        """Store the encoder and build the link-scoring head.

        Args:
            gnn: Encoder providing ``prepare_node_features`` and
                ``forward_once``.
            gnn_type: Stored on the instance as ``gnn_type``; not otherwise
                used in this class's visible methods.
        """
        super().__init__()
        self.gnn = gnn
        self.gnn_type = gnn_type
        # The head width is fixed to the text embedding size from the config.
        self.hidden_dim = text_embed_size
        self.dropout_rate = 0.2
        self.dropout = nn.Dropout(self.dropout_rate)
        # ``edge_loss`` no longer goes through these two modules; they are
        # kept as attributes so external code that references them still works.
        self.sigmoid = nn.Sigmoid()
        self.bce_loss = nn.BCELoss()
        self.link_classifier = LinkClassificationHead(self.hidden_dim, self.dropout_rate)
        self.transform = T.Compose([T.RandomLinkSplit()])

    def forward(self, data):
        """Encode ``data`` and return link scores, targets and the link loss.

        Args:
            data: Graph batch exposing ``x``, ``edge_index``, ``edge_label``,
                ``edge_attr``, ``pos``, ``mask`` and ``comm``.

        Returns:
            ``(preds, trues, loss)``: raw pairwise link scores, the dense
            adjacency built from ``edge_index``, and the loss with a leading
            dimension of size 1.

        Raises:
            SmallGraphError: If ``edge_index`` has fewer than two dimensions.
        """
        edge_index = data.edge_index
        if edge_index.ndim < 2:
            raise SmallGraphError(
                f"edge_index must be 2-D, got shape {tuple(edge_index.shape)}"
            )

        # The first return value of prepare_node_features is not needed here.
        _, node_features = self.gnn.prepare_node_features(data.x, data.mask)
        num_nodes = node_features.shape[0]

        # Only the node embeddings are consumed; the other outputs are ignored.
        node_embeds, _, _, _, _, _ = self.gnn.forward_once(
            num_nodes,
            node_features,
            edge_index,
            data.edge_label,
            data.edge_attr,
            data.pos,
            data.comm,
            batch_apply=True,
        )

        preds, trues = self.link_classifier(node_embeds, edge_index, num_nodes)
        # The scalar loss is lifted to shape [1]
        # (presumably so per-replica losses can be gathered - TODO confirm).
        loss = self.edge_loss(preds, trues).unsqueeze(0)
        return preds, trues, loss

    def edge_loss(self, preds, trues):
        """Return the mean binary cross-entropy of logits ``preds`` against ``trues``.

        Uses the fused logits form: applying sigmoid first saturates to
        exactly 0/1 for large |logit|, which caps the loss and zeroes the
        gradient for confidently wrong predictions.
        """
        return F.binary_cross_entropy_with_logits(preds, trues)
    