#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/7/6 3:31
# @File : LRL.py
# @Software: PyCharm
import torch
import dgl
import dgl.function as fn
import numpy as np
from dgl.nn import RelGraphConv, GATConv, GraphConv, SGConv
# from TKG.utils import comp_deg_norm
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax


# class LRL(nn.Module):
#     def __init__(self, graph, num_nodes, num_rels, time_length, time_idx, h_dim, out_dim, max_length=10, a_layer_num=2,
#                  d_layer_num=1, intra_k=5, inter_k=5, ori_graph=False, encoder='regcn', fuse='att', decoder='hgt',
#                  attn_drop=0.3, feat_drop=0.3, last=True, ori=True, norm=False, rel_update=True, relation_prediction=True, low_memory=True):
#         super(LRL, self).__init__()
#         self.g = graph
#         self.num_nodes = num_nodes
#         self.num_rels = num_rels * 2
#         self.time_rels = time_length
#         self.time_length = time_length  # total time span
#         self.time_idx = time_idx        # indices of the nodes in the graph at each timestamp
#         self.h_dim = h_dim
#         self.out_dim = out_dim
#         self.a_layer_num = a_layer_num
#         self.d_layer_num = d_layer_num
#         self.intra_k = intra_k
#         self.inter_k = inter_k
#         self.ori_graph = ori_graph
#         self.en_embedding = None
#         self.max_length = max_length
#         self.relation_prediction = relation_prediction
#         self.low_memory = low_memory
#         self.encoder = encoder
#         self.fuse_g = fuse
#         self.decoder = decoder
#         self.attn_drop = attn_drop
#         self.feat_drop = feat_drop
#         self.last = last
#         self.ori = ori
#         self.norm = norm
#         self.rel_embedding = None

#         # parameters required for structure learning
#         self.linear_l2 = nn.Linear(self.h_dim, self.h_dim, bias=True)
#         self.linear_s2 = nn.Linear(self.h_dim, self.h_dim, bias=True)
#         self.fuse_f2 = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_w1 = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_w2 = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_r = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_s1 = nn.Linear(self.h_dim, self.h_dim, bias=True)
#         self.linear_s2 = nn.Linear(self.h_dim, self.h_dim, bias=True)
#         # 考虑计算时间
#         self.linear_s3 = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_s4 = nn.Linear(self.h_dim, 1, bias=False)
#         self.linear_r = nn.Linear(self.h_dim, 1, bias=False)
#         self.gate = GatingMechanism(self.num_nodes, self.h_dim)
#         self.rel_update = rel_update

#         dim = 1
#         if self.last:
#             dim += 1
#         if self.ori:
#             dim += 1
#         self.linear_1 = nn.Linear(self.out_dim * dim, self.out_dim, bias=False)
#         self.rnn = nn.GRU(self.h_dim, self.h_dim, batch_first=True)
#         self.reset_parameters()
#         self.aggregator = None
#         # initialize the decoder
#         self.decoder_f = GNN(self.h_dim, self.h_dim, layer_num=self.d_layer_num, gnn=self.decoder,
#                              attn_drop=self.attn_drop, feat_drop=self.feat_drop)

#     def forward(self, data_list, node_id_new=None, device=None, mode='test'):
#         h = F.normalize(self.en_embedding(self.g.ndata['id']))
#         t_id = self.g.ndata['t_id']
#         # **********Structure Encoder*******
#         all_graph = data_list['all_graph'].to(device)
#         pre_all_id = data_list['pre_all_nid']
#         # norm = comp_deg_norm(all_graph)
#         all_graph.ndata.update({'norm': norm.view(-1, 1)})
#         all_graph.apply_edges(
#             lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
#         global_re = self.aggregator(
#             all_graph, h[pre_all_id], self.rel_embedding.weight[0:self.num_rels * 2], 1)
#         h[pre_all_id] = F.normalize(global_re)
#         # ========================================
#         h_candidate = h.clone()
#         new_feature = h.clone()
#         # *******Latent Relations Learning******
#         sub_graph = data_list['sub_graph'].to(device)
#         pre_id = data_list['pre_e_nid']
#         t_index = t_id[pre_id]
#         new_feature = F.normalize(new_feature)
#         new_feature[pre_id] = F.normalize(self.generate_new_embedding(sub_graph, new_feature[pre_id],
#                                                                       t_index, self.intra_k, self.inter_k))
#         all_list = data_list['all_list']
#         all_length = data_list['all_length']
#         s_len, s_idx = all_length.sort(0, descending=True)
#         num_non_zero = len(torch.nonzero(s_len))
#         s_len_non_zero = s_len[:num_non_zero]
#         if len(s_len_non_zero) > 0:
#             r_input = new_feature[all_list[s_idx[0:num_non_zero]]]
#             packed_input = torch.nn.utils.rnn.pack_padded_sequence(
#                 r_input, s_len_non_zero, batch_first=True)
#             tt, s_h = self.rnn(packed_input)
#             new_feature[node_id_new[s_idx[0:num_non_zero]]
#                         ] = F.normalize(s_h.squeeze(0))
#             # candidate embedding
#             r_can_input = h_candidate[all_list[s_idx[0:num_non_zero]]]
#             packed_can_input = torch.nn.utils.rnn.pack_padded_sequence(
#                 r_can_input, s_len_non_zero, batch_first=True)
#             tt_can, s_can_h = self.rnn(packed_can_input)
#             h_candidate[node_id_new[s_idx[0:num_non_zero]]
#                         ] = F.normalize(s_can_h.squeeze(0))

#         h_candidate = h_candidate[node_id_new]
#         new_list = [new_feature[node_id_new]]
#         if self.last:
#             new_list.append(F.normalize(h[node_id_new]))
#         if self.ori:
#             new_list.append(F.normalize(self.en_embedding.weight))
#         new_em = self.linear_1(torch.cat(new_list, 1))
#         return new_em, h_candidate

#     def generate_new_embedding(self, new_graph, evolve_embs, t_index=None, intra_k=None, inter_k=None):
#         new_feature = self.refine_graph1(
#             new_graph, evolve_embs, t_index, intra_k, inter_k)
#         return new_feature

#     def refine_graph1(self, graph, feature, t_index=None, intra_k=None, inter_k=None):
#         # Latent relation graph
#         row1, col1 = graph.all_edges()
#         num_nodes = graph.num_nodes()
#         nodes = torch.arange(num_nodes)
#         feature1 = self.linear_s1(feature)
#         feature2 = self.linear_s2(feature)
#         sim_adj = self.build_sim(feature1, feature2)
#         sim_adj[torch.arange(feature.shape[0]),
#                 torch.arange(feature.shape[0])] = -1e9
#         if not self.ori:
#             sim_adj[row1, col1] = -1e9
#         if intra_k > 0 or inter_k > 0:
#             graph = dgl.DGLGraph().to(feature.device)
#             graph.add_nodes(num_nodes)
#             # intra-time latent relations
#             if intra_k > 0:
#                 intra_mask = t_index.unsqueeze(1) != t_index
#                 intra_row, intra_col = generate_adj(
#                     sim_adj, intra_mask, intra_k, num_nodes, feature.device)
#                 graph.add_edges(torch.cat([intra_row, intra_col]), torch.cat([intra_col, intra_row]),
#                                 data={'etype': torch.LongTensor(self.num_rels * np.ones(len(intra_col) * 2)).to(
#                                     feature.device)})
#             # inter-time latent relations
#             if inter_k > 0:
#                 inter_mask = t_index.unsqueeze(1) == t_index
#                 inter_row, inter_col = generate_adj(
#                     sim_adj, inter_mask, inter_k, num_nodes, feature.device)
#                 graph.add_edges(torch.cat([inter_row, inter_col]), torch.cat([inter_col, inter_row]),
#                                 data={'etype': torch.LongTensor((self.num_rels+1) * np.ones(len(inter_col) * 2)).to(
#                                     feature.device)})

#         graph.edata['r_h'] = self.rel_embedding(graph.edata['etype'])
#         new_feature = self.decoder_f(graph, feature)
#         return new_feature

#     def build_sim(self, context1, context2):
#         context1_norm = context1.div(torch.norm(
#             context1, p=2, dim=-1, keepdim=True))
#         context2_norm = context1.div(torch.norm(
#             context2, p=2, dim=-1, keepdim=True))
#         sim = torch.mm(context1_norm, context2_norm.transpose(1, 0))
#         return sim

#     def fuse_function(self, new_embedding, old_embedding, node_id=None):
#         if self.fuse_g == 'con':
#             return self.linear_fuse(torch.cat((new_embedding, old_embedding), -1))
#         elif self.fuse_g == 'gate':
#             return self.gate(new_embedding, old_embedding, node_id)
#         elif self.fuse_g == 'att':
#             return self.fuse_attention1(new_embedding, old_embedding)
#         elif self.fuse_g == 'ori':
#             return old_embedding
#         elif self.fuse_g == 'new':
#             return new_embedding
#         else:
#             return new_embedding + old_embedding

#     def fuse_attention1(self, s_embedding, l_embedding):
#         w1 = self.fuse_f(torch.tanh(self.linear_s(s_embedding)))
#         w2 = self.fuse_f(torch.tanh(self.linear_l(l_embedding)))
#         aff = F.softmax(torch.cat((w1, w2), 1), 1)
#         en_embedding = aff[:, 0].unsqueeze(
#             1) * s_embedding + aff[:, 1].unsqueeze(1) * l_embedding
#         return en_embedding

#     def reset_parameters(self):
#         # stdv = 1.0 / math.sqrt(self.hidden_size)
#         # for weight in self.parameters():
#         #     weight.data.uniform_(-stdv, stdv)
#         gain = nn.init.calculate_gain('relu')
#         for weight in self.parameters():
#             if len(weight.shape) > 1:
#                 nn.init.xavier_normal_(weight, gain=gain)


# RGCN net
class RGCN(nn.Module):
    """Stack of ``RelGraphConv`` layers with basis regularisation.

    Every layer maps ``in_dim`` features to ``in_dim`` features, applies
    ReLU and a dropout of 0.2.

    Args:
        in_dim: feature size consumed and produced by every layer.
        out_dim: accepted for interface compatibility; not used by this
            class (all layers are ``in_dim -> in_dim``).
        num_rels: number of relation types passed to ``RelGraphConv``.
        layer_num: number of stacked layers.
        low_memory: forwarded to ``RelGraphConv`` as ``low_mem``.
    """

    def __init__(self, in_dim, out_dim, num_rels, layer_num, low_memory=False):
        super(RGCN, self).__init__()
        self.h_dim = in_dim
        self.num_rels = num_rels
        self.layer_num = layer_num
        # ``activation`` must be the callable itself. Calling ``F.relu()``
        # here (no input tensor) raises TypeError during construction.
        self.layer = nn.ModuleList(
            RelGraphConv(self.h_dim, self.h_dim, num_rels=self.num_rels,
                         regularizer='basis', low_mem=low_memory,
                         dropout=0.2, activation=F.relu)
            for _ in range(self.layer_num))

    def forward(self, graph, features, etypes):
        """Run ``features`` through every layer in order.

        Args:
            graph: graph passed unchanged to each layer.
            features: input node features.
            etypes: edge-type argument passed unchanged to each layer.

        Returns:
            The output of the last layer (``features`` itself when the
            stack is empty).
        """
        for conv in self.layer:
            features = conv(graph, features, etypes)
        return features


# class GNN(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim):
#         super(GNN, self).__init__()
#         self.layer1 = dgl.nn.GraphConv(in_dim, hidden_dim).cuda()
#         print("FINISHED_LAYER1_INIT")
#         self.layer2 = dgl.nn.GraphConv(hidden_dim, out_dim).cuda()
#         print("FINISHED_LAYER2_INIT")
#         self.fc = nn.Linear(out_dim, out_dim).to("cuda")

#         self.layer1_norm = nn.LayerNorm(hidden_dim).cuda()
#         self.layer2_norm = nn.LayerNorm(out_dim).cuda()
#         self.layer3_norm = nn.LayerNorm(in_dim).cuda()
#         self.dropout = nn.Dropout(0.3)

#     def forward(self, g, node_feat):

#         node_feat = self.layer1(g, node_feat)
#         node_feat = self.layer1_norm(node_feat)
#         node_feat = self.dropout(node_feat)

#         node_feat = self.layer2(g, node_feat)
#         node_feat = self.layer2_norm(node_feat)
#         node_feat = self.dropout(node_feat)

#         node_feat = self.fc(node_feat)
#         node_feat = self.layer3_norm(node_feat)

#         return node_feat


class GNN(nn.Module):
    """Stack of graph-convolution layers with LayerNorm and a residual head.

    Each convolution output is passed through its own ``LayerNorm``. After
    the stack, a dropout -> linear -> LayerNorm branch is added back onto
    the features as a residual.

    Args:
        in_dim: feature size used by every layer (stored as ``h_dim``).
        out_dim: only sizes ``out_norm``, which ``forward`` does not apply.
        layer_num: number of stacked convolution layers.
        gnn: layer type; one of ``'rgcn'``, ``'gat'``, ``'gcn'``,
            ``'rgat'`` or ``'rgat_t'``.
        num_rels: number of relation types (passed to ``RelGraphConv``).
        attn_drop: attention dropout for ``'gat'`` / ``'rgat'`` layers.
        feat_drop: feature dropout for the attention layers and for the
            feed-forward head.
        num_head: number of attention heads; required for ``'gat'`` and
            must divide ``in_dim``.
        low_memory: forwarded to ``RelGraphConv`` as ``low_mem``.

    Raises:
        ValueError: if ``gnn`` is not a supported layer type, or if
            ``gnn == 'gat'`` and ``num_head`` is missing or does not
            divide ``in_dim``.
    """

    def __init__(self, in_dim, out_dim, layer_num, gnn='rgcn', num_rels=None,
                 attn_drop=0.25, feat_drop=0.3, num_head=None, low_memory=False):
        super(GNN, self).__init__()
        self.h_dim = in_dim
        self.out_dim = out_dim
        self.layer_num = layer_num
        self.num_rels = num_rels
        self.gnn = gnn
        self.attn_drop = attn_drop
        self.feat_drop = feat_drop

        # Residual feed-forward head applied after the convolution stack.
        self.fc_dropout = nn.Dropout(feat_drop)
        self.fc_layer = nn.Linear(in_dim, in_dim, bias=True)
        self.fc_norm = nn.LayerNorm(self.h_dim)

        # Not applied in forward(); kept so the parameter set is unchanged.
        self.out_norm = nn.LayerNorm(self.out_dim)

        self.layer_norm = nn.ModuleList(
            [nn.LayerNorm(self.h_dim) for _ in range(self.layer_num)]
        )

        if self.gnn == 'rgcn':
            self.layer = nn.ModuleList(
                RelGraphConv(self.h_dim, self.h_dim, num_rels=self.num_rels,
                             regularizer='basis', num_bases=100,
                             low_mem=low_memory, dropout=0.5,
                             activation=F.relu)
                for _ in range(self.layer_num))
        elif self.gnn == 'gat':
            if not num_head or num_head < 0 or self.h_dim % num_head:
                raise ValueError(
                    "gnn='gat' requires a positive num_head that divides "
                    "in_dim, got num_head={!r}, in_dim={!r}".format(
                        num_head, self.h_dim))
            # Per-head size is chosen so that the concatenated heads have
            # exactly h_dim features again.
            self.layer = nn.ModuleList(
                GATConv(self.h_dim, self.h_dim // num_head, num_head,
                        feat_drop=self.feat_drop, attn_drop=self.attn_drop,
                        activation=F.elu)
                for _ in range(self.layer_num))
        elif self.gnn == 'gcn':
            self.layer = nn.ModuleList(
                GraphConv(self.h_dim, self.h_dim, norm='none',
                          activation=F.relu)
                for _ in range(self.layer_num))
        elif self.gnn in ('rgat', 'rgat_t'):
            # 'rgat_t' is accepted here because both forward() and
            # FastRGATLayer already treat it the same way as 'rgat'.
            self.layer = nn.ModuleList(
                FastRGATLayer(self.h_dim, self.h_dim, self.feat_drop,
                              self.attn_drop, self.gnn)
                for _ in range(self.layer_num))
        else:
            # Without this, self.layer is never created and the failure
            # only shows up later as an AttributeError in forward().
            raise ValueError("unsupported gnn type: {!r}".format(self.gnn))

    def forward(self, graph, feature, etypes=None, edge_weight=None):
        """Apply the convolution stack and the residual feed-forward head.

        Args:
            graph: graph handed to every convolution layer. For the
                ``'rgat'`` variants the layer reads ``graph.edata['r_h']``.
            feature: input node features.
            etypes: edge types, used only when ``gnn == 'rgcn'``.
            edge_weight: edge weights, used only when ``gnn == 'gcn'``.

        Returns:
            Node features after the stack plus the feed-forward residual.
        """
        for conv, norm in zip(self.layer, self.layer_norm):
            if self.gnn == 'rgcn':
                feature = conv(graph, feature, etypes)
            elif self.gnn in ('rgat', 'rgat_t'):
                feature = conv(graph, feature)
            elif self.gnn == 'gat':
                # GATConv keeps a separate head axis; merge the heads so the
                # result is (N, h_dim) and matches LayerNorm(h_dim) and the
                # next layer's input size.
                feature = conv(graph, feature).flatten(1)
            elif self.gnn == 'gcn':
                feature = conv(graph, feature, edge_weight=edge_weight)
            feature = norm(feature)

        # Feed the *dropped-out* features to the linear layer. The original
        # passed ``feature`` again, silently discarding the dropout result.
        fc_feature = self.fc_dropout(feature)
        fc_feature = self.fc_layer(fc_feature)
        fc_feature = self.fc_norm(fc_feature)

        # out_norm is deliberately not applied (it was disabled upstream).
        return feature + fc_feature

    def generate_graph(self, graph, feature, etypes=None):
        """Annotate ``graph`` with layer outputs and per-edge attention.

        Stores the layer output in ``graph.ndata['z']``, replaces
        ``graph.edata['r_h']`` with ``conv.fc_r(graph.edata['r_h'])`` and
        calls ``graph.apply_edges(conv.edge_attention)`` using the last
        layer of the stack.

        NOTE(review): every layer is applied to the original ``feature``
        (outputs are not chained), so only the last layer's output is
        kept - confirm this is intended.
        NOTE(review): only ``'rgcn'`` and ``'rgat'`` are handled; any other
        ``gnn`` value leaves ``features`` unbound.
        NOTE(review): this relies on ``conv.fc_r`` and
        ``conv.edge_attention``; ``FastRGATLayer`` as defined in this file
        has no ``edge_attention`` method - verify before using.

        Args:
            graph: graph to annotate in place; must carry ``edata['r_h']``.
            feature: input node features given to every layer.
            etypes: edge types, used only when ``gnn == 'rgcn'``.

        Returns:
            The same ``graph`` object.
        """
        for conv in self.layer:
            if self.gnn == 'rgcn':
                features = conv(graph, feature, etypes)
            elif self.gnn == 'rgat':
                features = conv(graph, feature)
        graph.ndata['z'] = features
        graph.edata['r_h'] = conv.fc_r(graph.edata['r_h'])
        graph.apply_edges(conv.edge_attention)
        return graph


class FastRGATLayer(nn.Module):
    """Relation-aware graph attention layer with a residual connection.

    The attention logit of an edge is the sum of a source-node score, a
    destination-node score and a relation score computed from
    ``g.edata['r_h']``, passed through LeakyReLU and normalised with
    ``edge_softmax``.

    For ``gnn`` in ``('rgat', 'rgat_t')`` node features are first projected
    by ``fc``; for any other value they are used as-is. The residual
    ``rst + h`` in ``forward`` means input and output feature sizes have to
    agree.

    Args:
        in_dim: input feature size.
        out_dim: output feature size, also the size the attention scorers
            expect for node and relation features.
        feat_drop: dropout applied to the input node features.
        attn_drop: dropout applied to the attention coefficients.
        gnn: mode name; ``'rgat'`` / ``'rgat_t'`` enable the projections.
    """

    # Modes in which node features go through the learned projection ``fc``.
    _PROJECTED_MODES = ('rgat', 'rgat_t')

    def __init__(self, in_dim, out_dim, feat_drop=0.3, attn_drop=0.3, gnn='rgat_r'):
        super(FastRGATLayer, self).__init__()
        self.gnn = gnn
        # forward() uses the three scorers in every mode, so they must
        # always exist. Previously they were only created for the projected
        # modes and the default gnn='rgat_r' failed with AttributeError.
        self.attn_src = nn.Linear(out_dim, 1, bias=False)
        self.attn_dst = nn.Linear(out_dim, 1, bias=False)
        self.attn_rel = nn.Linear(out_dim, 1, bias=False)
        if self.gnn in self._PROJECTED_MODES:
            self.fc = nn.Linear(in_dim, out_dim, bias=False)
            # Not used by forward(); GNN.generate_graph reads it.
            self.fc_r = nn.Linear(in_dim, out_dim, bias=False)
        # Not used by forward(); kept so the parameter set is unchanged.
        self.loop_weight = nn.Parameter(torch.Tensor(out_dim, out_dim))
        self.reset_parameters()
        self.feat_drop = nn.Dropout(feat_drop)
        self.atten_drop = nn.Dropout(attn_drop)
        self.h_dim = out_dim

    def reset_parameters(self):
        """Reinitialize learnable parameters (Xavier uniform, ReLU gain).

        ``fc_r`` keeps the default ``nn.Linear`` initialisation.
        """
        gain = nn.init.calculate_gain('relu')
        if self.gnn in self._PROJECTED_MODES:
            nn.init.xavier_uniform_(self.fc.weight, gain=gain)
        nn.init.xavier_uniform_(self.attn_src.weight, gain=gain)
        nn.init.xavier_uniform_(self.attn_dst.weight, gain=gain)
        nn.init.xavier_uniform_(self.attn_rel.weight, gain=gain)
        nn.init.xavier_uniform_(self.loop_weight, gain=gain)

    def forward(self, g, h):
        """Aggregate neighbour features with relation-aware attention.

        Args:
            g: graph whose ``edata['r_h']`` holds the per-edge relation
                features scored by ``attn_rel``.
            h: node features.

        Returns:
            ``relu(aggregated + dropout(h))`` for every node. Temporary
            node/edge fields are written inside ``g.local_scope()``.
        """
        with g.local_scope():
            h = self.feat_drop(h)
            # Source and destination share the same tensor, so one
            # projection is enough (the original ran ``fc`` twice on it).
            feat = self.fc(h) if self.gnn in self._PROJECTED_MODES else h
            src_score = self.attn_src(feat)
            dst_score = self.attn_dst(feat)
            rel_score = self.attn_rel(g.edata['r_h'])
            g.srcdata.update({'ft': feat, 'el': src_score})
            g.dstdata.update({'er': dst_score})
            g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            logits = F.leaky_relu(rel_score + g.edata.pop('e'))
            g.edata['a'] = self.atten_drop(edge_softmax(g, logits))
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            aggregated = g.dstdata['ft']
            return F.relu(aggregated + h)


class GatingMechanism(nn.Module):
    """Learnable per-entity, per-dimension gate that blends two embeddings.

    Args:
        entity_num: number of rows in the gate table.
        hidden_dim: embedding size (number of columns in the gate table).
    """

    def __init__(self, entity_num, hidden_dim):
        super(GatingMechanism, self).__init__()
        # Gate parameters: one logit per (entity, hidden dimension).
        self.gate_theta = nn.Parameter(torch.empty(entity_num, hidden_dim))
        nn.init.xavier_uniform_(self.gate_theta)

    def forward(self, X: torch.FloatTensor, Y: torch.FloatTensor, Node_id=None):
        """Return ``gate * X + (1 - gate) * Y``.

        :param X:       first embedding, weighted by ``gate``
        :param Y:       second embedding, weighted by ``1 - gate``
        :param Node_id: optional index selecting the gate rows that match
                        the rows of ``X`` / ``Y``; ``None`` uses the whole
                        gate table in order
        :return:        the gated combination, same shape as ``X``
        """
        theta = self.gate_theta
        if Node_id is not None:
            # Select rows before the sigmoid so only the needed gates are
            # evaluated. Indexing with None would add a spurious leading
            # axis, hence the explicit check.
            theta = theta[Node_id]
        gate = torch.sigmoid(theta)
        return gate * X + (1 - gate) * Y


def generate_adj(adj, mask, k, h, device):
    """Select the top-scoring columns of every row of a score matrix.

    Entries selected by ``mask`` and all negative entries are treated as
    score 0 before the selection. They can therefore still be returned
    when a row has fewer than ``min(k, h)`` positive entries.

    Args:
        adj: score matrix with ``h`` rows; it is left unmodified.
        mask: boolean index into ``adj`` marking entries to zero out.
        k: requested number of columns per row.
        h: number of rows/nodes; at most ``h`` columns are taken per row.
        device: device on which the row-index tensor is created.

    Returns:
        ``(new_row, new_col)``: two flat index tensors of length
        ``h * min(k, h)``. ``new_row`` repeats each row id ``min(k, h)``
        times and ``new_col`` holds the matching top-k column ids.
    """
    # Work on a copy: writing into the caller's matrix would corrupt it for
    # a later call with a different mask.
    scores = adj.clone()
    scores[mask] = 0
    scores[scores < 0] = 0

    top = min(k, h)
    _, ind = torch.topk(scores, top, dim=-1)
    new_row = torch.arange(h, device=device).repeat_interleave(top)
    new_col = ind.reshape(-1)
    return new_row, new_col
