from typing import Dict, List, Literal, Optional, Tuple

import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import EGATConv
from transformers import PreTrainedModel, PreTrainedTokenizer

from .mega_hparams import MEGAHyperParams
from .repr_tools import get_hidden_state


# Build a feature-annotated DGL graph from a list of (subject, relation, target) triples.
def build_graph_from_triples(
        triples: list,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        hparams: MEGAHyperParams
) -> Tuple[dgl.DGLGraph, Dict[str, int], torch.Tensor]:
    """Build a CUDA DGL graph whose nodes are entities and edges are relations.

    Every distinct relation string in ``triples`` is embedded once with
    ``get_hidden_state``. Only the first ``hparams.subgraph_size`` triples
    contribute nodes and edges. Each distinct subject/target string becomes
    exactly one node and is embedded once with ``get_hidden_state``.

    Args:
        triples: Sequence of dicts with ``'subject'``, ``'relation'`` and
            ``'target'`` keys.
        model: Language model forwarded to ``get_hidden_state``.
        tokenizer: Tokenizer forwarded to ``get_hidden_state``.
        hparams: Hyper-parameters. ``subgraph_size`` caps how many triples
            become edges; the object is also forwarded to ``get_hidden_state``.

    Returns:
        A tuple ``(g, node_indices, init_rel_emb)``:

        - ``g``: graph on ``"cuda"`` with ndata ``feat``/``id``/``norm`` and
          edata ``r_h``/``etype``.
        - ``node_indices``: maps each entity string to its node id.
        - ``init_rel_emb``: one stacked embedding per relation type, where
          row ``i`` belongs to relation id ``i`` (the values stored in
          ``etype``).
    """
    # Sort relation names so relation ids are identical across runs.
    # Iterating a set of strings depends on hash randomization.
    relation_types = sorted({triple['relation'] for triple in triples})
    relation_to_id = {rel: idx for idx, rel in enumerate(relation_types)}

    # Initial relation embeddings; row i corresponds to relation id i.
    init_rel_emb = torch.stack([
        get_hidden_state(model, tokenizer, hparams, relation)
        for relation in relation_types
    ]).to("cuda")

    print("Iterating Triples")

    nodes = {}          # entity string -> embedding
    edges = []          # (subject_str, target_str) for each kept triple
    edge_type_ids = []  # relation id for each kept triple, parallel to `edges`

    for triple in triples[:hparams.subgraph_size]:
        subject_str = triple['subject']
        target_str = triple['target']
        # Embed each distinct entity only once; repeat occurrences reuse the
        # first embedding (the original also kept only the first one).
        for entity in (subject_str, target_str):
            if entity not in nodes:
                nodes[entity] = get_hidden_state(
                    model, tokenizer, hparams, entity)
        edges.append((subject_str, target_str))
        edge_type_ids.append(relation_to_id[triple['relation']])

    # Sort entity names so node ids are the same for each run.
    nodes_list = sorted(nodes)
    node_indices = {node: index for index, node in enumerate(nodes_list)}

    # Edges are stored target -> subject, i.e. reversed w.r.t. the triple.
    # NOTE(review): presumably so subject nodes aggregate messages from their
    # targets; confirm this direction is intended.
    src_ids = [node_indices[target] for _, target in edges]
    dst_ids = [node_indices[subject] for subject, _ in edges]

    g = dgl.DGLGraph().to("cuda")
    g.add_nodes(len(nodes_list))
    g.add_edges(src_ids, dst_ids)

    g.ndata['feat'] = torch.stack(
        [nodes[node].cpu() for node in nodes_list]).to("cuda")
    # Node i is nodes_list[i], so the id feature is simply 0..N-1.
    g.ndata['id'] = torch.arange(
        len(nodes_list), dtype=torch.long, device="cuda")

    etype = torch.tensor(edge_type_ids, dtype=torch.long, device="cuda")
    # A row lookup gives each edge the embedding of its relation type.
    g.edata['r_h'] = init_rel_emb[etype]
    g.edata['etype'] = etype

    # NOTE(review): the DGL docs say add_self_loop fills the edata of the new
    # self-loop edges with a constant (zeros or `fill_data`, depending on the
    # DGL version). Their 'etype' therefore aliases a real relation id, and
    # their 'r_h' is not a relation embedding. Confirm this is intended.
    g = dgl.add_self_loop(g)

    # Every node now has a self-loop, so in-degree >= 1 and 1/deg is finite.
    node_norm = torch.pow(g.in_degrees().float(), -1)
    g.ndata['norm'] = node_norm.view(-1, 1).to("cuda")

    print("Finished building graph")

    return g, node_indices, init_rel_emb

# GNN / EGAT network definitions (currently commented out).


# class GNN(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim):
#         super(GNN, self).__init__()
#         self.layer1 = dgl.nn.GraphConv(in_dim, hidden_dim).cuda()
#         print("FINISHED_LAYER1_INIT")
#         self.layer2 = dgl.nn.GraphConv(hidden_dim, out_dim).cuda()
#         print("FINISHED_LAYER2_INIT")
#         self.fc = nn.Linear(out_dim, out_dim).to("cuda")

#         self.layer1_norm = nn.LayerNorm(hidden_dim).cuda()
#         self.layer2_norm = nn.LayerNorm(out_dim).cuda()
#         self.layer3_norm = nn.LayerNorm(in_dim).cuda()
#         self.dropout = nn.Dropout(0.3)

#     def forward(self, g, node_feat):

#         node_feat = self.layer1(g, node_feat)
#         node_feat = self.layer1_norm(node_feat)
#         node_feat = self.dropout(node_feat)

#         node_feat = self.layer2(g, node_feat)
#         node_feat = self.layer2_norm(node_feat)
#         node_feat = self.dropout(node_feat)

#         node_feat = self.fc(node_feat)
#         node_feat = self.layer3_norm(node_feat)

#         return node_feat


# class EGATNet(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim, num_heads):
#         super(EGATNet, self).__init__()
#         self.num_heads = num_heads

#         # 第一层EGAT卷积
#         self.layer1 = EGATConv(
#             in_dim, in_dim, hidden_dim // num_heads,
#             hidden_dim // num_heads, num_heads
#         ).to("cuda")

#         # 第二层EGAT卷积
#         self.layer2 = EGATConv(
#             hidden_dim, hidden_dim, out_dim // num_heads,
#             out_dim // num_heads, num_heads
#         ).to("cuda")

#         # 第一层的线性层和LayerNorm层
#         self.layer1_node_fc = nn.Linear(hidden_dim, hidden_dim).to("cuda")
#         self.layer1_edge_fc = nn.Linear(hidden_dim, hidden_dim).to("cuda")

#         self.layer1_node_norm = nn.LayerNorm(hidden_dim).to("cuda")
#         self.layer1_edge_norm = nn.LayerNorm(hidden_dim).to("cuda")
#         self.layer1_fc_node_norm = nn.LayerNorm(hidden_dim).to("cuda")
#         self.layer1_fc_edge_norm = nn.LayerNorm(hidden_dim).to("cuda")

#         # 第二层的线性层和LayerNorm层
#         self.layer2_node_fc = nn.Linear(out_dim, out_dim).to("cuda")

#         self.layer2_node_norm = nn.LayerNorm(out_dim).to("cuda")
#         self.layer2_fc_norm = nn.LayerNorm(out_dim).to("cuda")

#         # Dropout
#         self.dropout = nn.Dropout(0.3)

#     def forward(self, g, node_feat, edge_feat):
#         # 第一层图卷积
#         node_feat, edge_feat = self.layer1(g, node_feat, edge_feat)
#         node_feat = node_feat.view(node_feat.size(0), -1)
#         edge_feat = edge_feat.view(edge_feat.size(0), -1)
#         node_feat = self.layer1_node_norm(node_feat)
#         edge_feat = self.layer1_edge_norm(edge_feat)

#         node_feat = self.dropout(node_feat)
#         edge_feat = self.dropout(edge_feat)

#         node_feat = self.layer1_node_fc(node_feat)
#         edge_feat = self.layer1_edge_fc(edge_feat)
#         node_feat = self.layer1_fc_node_norm(node_feat)
#         edge_feat = self.layer1_fc_edge_norm(edge_feat)

#         node_feat = self.dropout(node_feat)
#         edge_feat = self.dropout(edge_feat)

#         # 第二层图卷积
#         node_feat, edge_feat = self.layer2(g, node_feat, edge_feat)
#         node_feat = node_feat.view(node_feat.size(0), -1)
#         node_feat = self.layer2_node_norm(node_feat)

#         node_feat = self.dropout(node_feat)

#         node_feat = self.layer2_node_fc(node_feat)
#         node_feat = self.layer2_fc_norm(node_feat)

#         return node_feat


def check_device(model):
    """Print the device that each named parameter of ``model`` lives on.

    Debugging aid: emits one ``Layer: <name> Device: <device>`` line per
    parameter, in ``model.named_parameters()`` order.
    """
    named = model.named_parameters()
    for layer_name, weight in named:
        where = weight.device
        print('Layer:', layer_name, 'Device:', where)


# def get_subject_feature(
#         triples: list,
#         gnn_model: GNN,
#         model: PreTrainedModel,
#         tokenizer: PreTrainedTokenizer,
#         hparams: MEGAHyperParams
# ) -> torch.Tensor:

#     n_embed = model.config.n_embd if hasattr(
#         model.config, "n_embed") else model.config.hidden_size

#     print("n_embed={}".format(n_embed))

#     # return torch.zeros(n_embed)  # check

#     # 提取图结构与特征
#     g, node_indices = build_graph_from_triples(
#         triples, model, tokenizer, hparams)

#     # 创建EGAT模型并进行初始化
#     print("Initializing EGATNet")
#     num_heads = 8  # 注意力头数量，还得再调

#     print("DEVICE_OF_ORIGINAL_GRAPH:{}".format(g.device))
#     egat_net = GNN(n_embed, n_embed, n_embed, num_heads)

#     check_device(egat_net)

#     node_features = g.ndata['r_h']
#     edge_features = g.edata['feat']

#     print("DEVICE_OF_FEATURES:{}".format(node_features.device))

#     # 运行EGAT模型
#     print("RUN EGATNet Model")
#     outputs = egat_net(g, node_features.float(), edge_features.float())

#     # 获取JSON的第一条数据的subject在EGAT中对应的节点的特征

#     subject_str = triples[0]['subject']
#     subject_id = node_indices[subject_str]

#     subject_feature = outputs[subject_id]
#     print("subject_feature extracted")

#     return subject_feature.detach()
