# %%
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
import openke_lib
from openke_lib.config import Trainer, Tester
from openke_lib.module.model import TransE
from openke_lib.module.loss import MarginLoss
from openke_lib.module.strategy import NegativeSampling
from openke_lib.data import TrainDataLoader, TestDataLoader

from data.rule import get_RULES
import networkx as nx
import matplotlib.pyplot as plt

def test_link_prediction(CogKE, data_loader, use_gpu):
    """Evaluate link prediction for ``CogKE`` from its saved checkpoint.

    Loads './checkpoint/CogKE.ckpt' into ``CogKE``, then runs the ``Tester``
    link-prediction routine over ``data_loader`` without type constraints.

    Args:
        CogKE: model exposing ``load_checkpoint``.
        data_loader: evaluation data loader handed to ``Tester``.
        use_gpu: forwarded to ``Tester``.
    """
    ckpt = './checkpoint/CogKE.ckpt'
    CogKE.load_checkpoint(ckpt)
    print('Test model Loaded!')
    evaluator = Tester(model=CogKE, data_loader=data_loader, use_gpu=use_gpu)
    evaluator.run_link_prediction(type_constrain=False)

def test_triple_classification(CogKE, data_loader, use_gpu):
    """Evaluate triple classification for ``CogKE`` from its saved checkpoint.

    Args:
        CogKE: model exposing ``load_checkpoint``.
        data_loader: evaluation data loader handed to ``Tester``.
        use_gpu: forwarded to ``Tester``.

    Returns:
        ``(acc, threshold)`` exactly as produced by
        ``Tester.run_triple_classification``. The accuracy is also printed.
    """
    CogKE.load_checkpoint('./checkpoint/CogKE.ckpt')
    # Pass the caller's flag through instead of forcing GPU execution.
    tester = Tester(model=CogKE, data_loader=data_loader, use_gpu=use_gpu)
    acc, threshold = tester.run_triple_classification()
    print(acc)
    return acc, threshold

def set_color(ReferGraph, V_knw, V_mem):
    """Return one colour per node of ``ReferGraph``, in its iteration order.

    Membership in ``V_knw`` wins (green), then ``V_mem`` (yellow); every other
    node is light blue.
    """
    def _colour_of(node):
        if node in V_knw:
            return 'green'
        if node in V_mem:
            return 'yellow'
        return 'lightblue'

    return [_colour_of(node) for node in ReferGraph]

def FindRuleNodes(G, ent2rule):
    """Return the set of rule names that ``ent2rule`` lists for any node of ``G``.

    Nodes that are not keys of ``ent2rule`` contribute nothing.
    """
    return {
        rule
        for node in G.nodes()
        if node in ent2rule
        for rule in ent2rule[node]
    }

def draw_graph(G, color_map=None):
    """Draw the graph ``G`` and block on ``plt.show()``.

    Node labels come from the ``desc`` node attribute and edge labels from the
    ``name`` edge attribute (the attributes get_ReferGraph sets).

    Args:
        G: networkx graph to draw.
        color_map: optional per-node colour list (e.g. from set_color), passed
            straight through as ``node_color``.
    """
    try:
        # Hierarchical top-down layout via Graphviz ``dot``.
        pos = nx.drawing.nx_agraph.graphviz_layout(G, prog='dot')
    except ImportError:
        # graphviz_layout requires pygraphviz; degrade to a pure-networkx layout
        # instead of aborting the whole inference run.
        pos = nx.shell_layout(G)

    nx.draw(G, pos, node_color=color_map)
    nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, 'desc'))
    nx.draw_networkx_edge_labels(G, pos, edge_labels=nx.get_edge_attributes(G, 'name'))
    plt.show()

def get_ReferGraph(initial_state, ent2id, ent2rule, rule_dict):
    '''Build the inference subgraph ("ReferGraph") reachable from ``initial_state``.

    Pseudocode:
        rule_dict: {rule name: ([input states], [output states])}
        Add every initial-state entity to the graph G (their ids form V_knw).
        V_mem <- {} (rules already expanded).
        V_cur <- rules that ent2rule links to the initial entities.
        While V_cur - V_mem is not empty:
            for each rule in V_cur - V_mem, add the rule node, its input and
            output entity nodes, and the edges input -> rule -> output
            (edge attribute ``name`` = rule name);
            V_mem <- V_mem | V_cur;
            V_cur <- rules that ent2rule links to any entity now in G.

    Entity nodes are keyed by their id (``ent2id[name]``) and carry that id as
    ``desc``. Rule nodes are keyed by, and labelled with, the rule name.

    Args:
        initial_state: iterable of entity names known to hold initially.
        ent2id: entity name -> entity id.
        ent2rule: entity name -> iterable of rule names linked to that entity.
        rule_dict: rule name -> (input entity names, output entity names).

    Returns:
        (ReferGraph, V_knw, V_mem): the networkx.DiGraph, the set of initial
        entity ids, and the set of rule names expanded into the graph.

    Raises:
        KeyError: if an initial or rule entity name is missing from ``ent2id``,
            or a linked rule is missing from ``rule_dict``.
    '''
    # Entity nodes are ids, but ent2rule is keyed by entity *name*, so ids must
    # be mapped back to names when searching for further rules.
    id2name = {idx: name for name, idx in ent2id.items()}

    def _linked_rules(keys):
        # Union of the ent2rule entries of every key that has one.
        rules = set()
        for key in keys:
            if key in ent2rule:
                rules.update(ent2rule[key])
        return rules

    V_knw = set()
    V_cur, V_mem = set(), set()
    ReferGraph = nx.DiGraph()
    for name in initial_state:
        ent_id = ent2id[name]
        V_knw.add(ent_id)
        ReferGraph.add_node(ent_id, desc=ent_id)
        if name in ent2rule:
            V_cur.update(ent2rule[name])

    while V_cur - V_mem:
        for rule_name in V_cur - V_mem:
            ReferGraph.add_node(rule_name, desc=rule_name)
            rule = rule_dict[rule_name]
            for rule_in in rule[0]:
                in_id = ent2id[rule_in]
                if in_id not in ReferGraph:
                    ReferGraph.add_node(in_id, desc=in_id)
                ReferGraph.add_edge(in_id, rule_name, name=rule_name)
            for rule_out in rule[1]:
                out_id = ent2id[rule_out]
                if out_id not in ReferGraph:
                    ReferGraph.add_node(out_id, desc=out_id)
                ReferGraph.add_edge(rule_name, out_id, name=rule_name)
        V_mem = V_mem | V_cur
        # Rule nodes are not in id2name and are probed under their own label.
        V_cur = _linked_rules(id2name.get(node, node) for node in ReferGraph.nodes())

    return ReferGraph, V_knw, V_mem

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, threshold=0.5):
    """Test whether entity ``in_no`` is reachable from a known entity via one relation.

    For each known entity h in ``V_knw`` and each relation r, the cosine
    similarity between (h + r) and t = ``id2embed_ent[in_no]`` is computed.
    The first (h, r) pair whose similarity exceeds ``threshold`` is printed and
    reported.

    Args:
        in_no: id of the candidate tail entity.
        V_knw: collection of known entity ids (set or dict keys).
        id2embed_ent: entity embeddings indexed by entity id.
        id2embed_rel: relation embeddings, enumerated in relation-id order.
        id2rel, id2ent: id -> name lookups, used only for the log line.
        threshold: strict lower bound on cosine similarity (default 0.5).

    Returns:
        ``(True, similarity)`` for the first match, otherwise ``(False, 0)``.
    """
    # The target embedding and its norm do not depend on (h, r): compute once.
    target = np.asarray(id2embed_ent[in_no])
    target_norm = np.linalg.norm(target)
    # Snapshot so the caller may mutate its collection afterwards.
    for knw in V_knw.copy():
        head = id2embed_ent[knw]
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            shifted = np.add(head, vec_rel)
            sim = float(np.dot(shifted, target) / (np.linalg.norm(shifted) * target_norm))
            if sim > threshold:
                print(f'{id2ent[knw]} + {id2rel[idx_rel]} -> {id2ent[in_no]}')
                return True, sim
    return False, 0

def update_ReferGraph(G, V_knw, V_mem, id2embed_ent, id2embed_rel, id2rel, id2ent, rule_dict):
    """Forward-chain over the inference subgraph ``G`` until no new node is added.

    Every pass visits each rule node in ``V_mem`` and does two things.

    * Link prediction: an input entity of the rule that is not yet known is
      accepted when ``check_triple`` finds a known entity h and a relation r
      with (h + r) cosine-similar to it.
    * Rule application: once all input entities of a rule are known, each of
      its output entities is added.

    Passes repeat until one of them adds nothing.

    Args:
        G: directed graph from get_ReferGraph (input -> rule -> output edges).
        V_knw: ids of the initially known entity nodes (not modified).
        V_mem: rule nodes of ``G`` to evaluate.
        id2embed_ent, id2embed_rel: entity / relation embeddings indexed by id.
        id2rel, id2ent: id -> name lookups, used for logging.
        rule_dict: rule name -> (inputs, outputs), used for logging.

    Returns:
        (Inferred_nodes, V_cur): the set of node ids added during inference,
        and the final set of known node ids (initial plus inferred).
    """
    Inferred_nodes = set()
    V_cur = set(V_knw)
    cnt = 1
    while True:
        print(f'--------Inferring Iteration {cnt}--------')
        known_before = V_cur.copy()
        for node in V_mem:
            in_nodes = [src for src, _ in G.in_edges(node)]
            out_nodes = [dst for _, dst in G.out_edges(node)]
            for in_no in in_nodes:
                if in_no in V_cur:
                    continue
                # check_triple returns (matched, similarity). Unpack the flag,
                # because a non-empty tuple is always truthy.
                matched, _ = check_triple(in_no, V_cur, id2embed_ent, id2embed_rel, id2rel, id2ent)
                if matched:
                    V_cur.add(in_no)
                    Inferred_nodes.add(in_no)
                    print(f'Link Prediction Applied! Add Node {in_no}:{id2ent[in_no]}.')

            if not set(in_nodes) <= V_cur:
                continue
            # Fire the rule for every output (rules may have zero or several).
            for out_node in out_nodes:
                if out_node not in V_cur:
                    print(f'Rule {node}:{rule_dict[node]} Applied! Add Node {out_node}:{id2ent[out_node]}.')
                    V_cur.add(out_node)
                    Inferred_nodes.add(out_node)

        cnt += 1
        if not V_cur - known_before:
            print('Finish Inferring!')
            break
    return Inferred_nodes, V_cur

def update_ReferGraph_prob(G, V_knw_prob, V_mem, id2embed_ent, id2embed_rel, id2rel, id2ent, ent2id, rule_dict):
    """Forward-chain over ``G`` while tracking a confidence score per node.

    Works like update_ReferGraph, with known nodes held in a dict of
    node id -> score.

    * A link-predicted input gets the cosine similarity reported by
      ``check_triple`` as its score.
    * Rule outputs get the minimum score among the rule's inputs, so a rule is
      only as certain as its weakest premise. A rule with no inputs yields 1.0.

    Passes repeat until one of them adds no node.

    Args:
        G: directed graph from get_ReferGraph (input -> rule -> output edges).
        V_knw_prob: initial node id -> score mapping (not modified).
        V_mem: rule nodes of ``G`` to evaluate.
        id2embed_ent, id2embed_rel: entity / relation embeddings indexed by id.
        id2rel, id2ent: id -> name lookups, used for logging.
        ent2id: entity name -> id, used to resolve rule inputs.
        rule_dict: rule name -> (input entity names, output entity names).

    Returns:
        (Inferred_nodes, V_cur): node id -> score for nodes added during
        inference, and the final node id -> score mapping (initial plus inferred).
    """
    Inferred_nodes = {}
    V_cur = V_knw_prob.copy()
    print(f'V_cur_prob:{V_cur}')

    cnt = 1
    while True:
        print(f'--------Inferring Iteration {cnt}--------')
        known_before = set(V_cur)
        for node in V_mem:
            in_nodes = [src for src, _ in G.in_edges(node)]
            out_nodes = [dst for _, dst in G.out_edges(node)]
            for in_no in in_nodes:
                if in_no in V_cur:
                    continue
                isTrue, prob = check_triple(in_no, V_cur, id2embed_ent, id2embed_rel, id2rel, id2ent)
                if isTrue:
                    V_cur[in_no] = prob
                    Inferred_nodes[in_no] = prob
                    print(f'Link Prediction Applied! Add Node {in_no}:{id2ent[in_no]} with prob {prob}.')

            if any(n not in V_cur for n in in_nodes):
                continue
            pending = [dst for dst in out_nodes if dst not in V_cur]
            if not pending:
                continue
            # Weakest-premise confidence. ``default`` covers rules without inputs.
            rule_prob = min((V_cur[ent2id[i]] for i in rule_dict[node][0]), default=1.0)
            for out_node in pending:
                print(f'Rule {node}:{rule_dict[node]} Applied! Add Node {out_node}:{id2ent[out_node]} with prob {rule_prob}.')
                V_cur[out_node] = rule_prob
                Inferred_nodes[out_node] = rule_prob

        cnt += 1
        if not set(V_cur) - known_before:
            print('Finish Inferring!')
            break
    return Inferred_nodes, V_cur

def train(CogKE, data_loader, train_times, alpha, use_gpu, margin=5.0,
          ckpt_path='./checkpoint/CogKE.ckpt', embed_path='./embed.vec'):
    """Train ``CogKE`` with negative sampling and a margin loss, then save it.

    Args:
        CogKE: knowledge-embedding model to train (TransE in this script).
        data_loader: training data loader. Its batch size configures
            ``NegativeSampling``.
        train_times: forwarded to ``Trainer`` as ``train_times``.
        alpha: forwarded to ``Trainer`` (presumably the learning rate; confirm).
        use_gpu: forwarded to ``Trainer``.
        margin: margin of ``MarginLoss`` (default 5.0).
        ckpt_path: where the model checkpoint is written.
        embed_path: where the embedding parameters are written. The script's
            ``__main__`` block later reads './embed.vec' with ``json.load``.
    """
    # Create the checkpoint directory up front. A missing directory should fail
    # (or be fixed) before a long training run, not after it.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # Wrap the model with negative sampling and a margin-based ranking loss.
    CogKE_NS = NegativeSampling(
        model=CogKE,
        loss=MarginLoss(margin=margin),
        batch_size=data_loader.get_batch_size(),
    )

    trainer = Trainer(model=CogKE_NS, data_loader=data_loader, train_times=train_times,
                      alpha=alpha, use_gpu=use_gpu)
    trainer.run()
    CogKE.save_checkpoint(ckpt_path)
    CogKE.save_parameters(embed_path)  # save the embedding vectors

def infer(initial_state, embedings, ent2id, ent2rule, rule_dict, id2rel, id2ent, triples=None):
    '''Run rule- and embedding-based inference on the cognitive graph (CogGraph).

    Outline:
        INPUT: CogGraph and the initial state S = S_0 = {s_0, s_1, ..., s_n}.
        E = TransE(CogGraph)   # embeddings learned beforehand (``embedings``)
        Build the inference subgraph.
        Update the inference subgraph until nothing new is inferred.
        RETURN the resulting state S.

    Args:
        initial_state: entity names known to hold initially.
        embedings: mapping holding 'ent_embeddings.weight' and
            'rel_embeddings.weight', each indexed by entity / relation id.
        ent2id, ent2rule, rule_dict: lookups consumed by get_ReferGraph.
        id2rel, id2ent: id -> name lookups.
        triples: optional iterable of gold ``(sub, obj, rel)`` name triples used
            to print (h + r) vs t cosine-similarity statistics. If omitted, a
            module-level ``triples`` (defined by this script's ``__main__``
            block) is used when present. Otherwise the statistics are skipped.

    Returns:
        (Inferred_nodes, V_knw_prob): newly inferred node id -> score, and the
        final node id -> score mapping including the initial state.
    '''
    id2embed_ent = embedings['ent_embeddings.weight']
    print(f'id2embed_ent size:{len(id2embed_ent)}')
    id2embed_rel = embedings['rel_embeddings.weight']

    if triples is None:
        # Backward compatibility: the script's __main__ defines ``triples`` globally.
        triples = globals().get('triples')
    if triples is not None:
        _report_gold_triple_similarity(triples, ent2id, id2rel, id2embed_ent, id2embed_rel)

    ReferGraph, V_knw, V_mem = get_ReferGraph(initial_state, ent2id=ent2id, ent2rule=ent2rule, rule_dict=rule_dict)

    # Probabilistic inference: initial facts are certain (score 1.0).
    # The non-probabilistic alternative is update_ReferGraph(...), which can be
    # visualised with set_color + draw_graph.
    V_knw_prob = {i: 1.0 for i in V_knw}
    Inferred_nodes, V_knw_prob = update_ReferGraph_prob(ReferGraph, V_knw_prob, V_mem, id2embed_ent, id2embed_rel, id2rel, id2ent, ent2id, rule_dict)

    print(f'Inferred Nodes:{Inferred_nodes}, {[id2ent[i] for i in Inferred_nodes]}')
    return Inferred_nodes, V_knw_prob


def _report_gold_triple_similarity(triples, ent2id, id2rel, id2embed_ent, id2embed_rel):
    """Print min/mean/median/max cosine similarity of (h + r) vs t over gold triples."""
    # Invert id2rel rather than relying on a global rel2id.
    rel2id = {name: idx for idx, name in id2rel.items()}
    dist_list = []
    for (sub, obj, rel) in triples:
        vec1 = np.add(id2embed_ent[ent2id[sub]], id2embed_rel[rel2id[rel]])
        vec2 = id2embed_ent[ent2id[obj]]
        dist_list.append(float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))))

    # Statistics of the cosine similarity between (h + r) and t over gold triples.
    print(f'统计gold triple (h, r, t)中 (h+r)与 t 的平均余弦相似度')
    if not dist_list:
        # np.min / np.max raise on empty input.
        print('no gold triples; skipping cosine statistics')
        return
    print(f'min cosine dist:{np.min(dist_list)}')
    print(f'avg cosine dist:{np.mean(dist_list)}')
    print(f'median cosine dist:{np.median(dist_list)}')
    print(f'max cosine dist:{np.max(dist_list)}')


#  %%
if __name__ == '__main__':
    data_path = './data/mini_CMeKG/'

    # Data loader for training.
    train_dataloader = TrainDataLoader(
        in_path=data_path,
        nbatches=200,
        threads=8,
        sampling_mode="normal",
        bern_flag=1,
        filter_flag=1,
        neg_ent=25,
        neg_rel=0)

    # Define the model.
    CogKE = TransE(
        ent_tot=train_dataloader.get_ent_tot(),
        rel_tot=train_dataloader.get_rel_tot(),
        dim=200,
        p_norm=1,
        norm_flag=True)

    # Uncomment to (re)train. This writes the checkpoint and ./embed.vec read below.
    # train(CogKE = CogKE, data_loader = train_dataloader, train_times = 2000, alpha = 1.0, use_gpu = True)

    # %%
    # Data loader for link-prediction evaluation.
    # test_dataloader = TestDataLoader(in_path="./data/mini_CMeKG/", sampling_mode="link", type_constrain=True)
    # test_link_prediction(CogKE = CogKE, data_loader = test_dataloader, use_gpu = True)

    # %%
    # Data loader for triple-classification evaluation.
    # test_dataloader = TestDataLoader("./data/", "link")
    # test_triple_classification(CogKE = CogKE, data_loader = test_dataloader, use_gpu = True)

    # %%
    # NOTE: allow_pickle=True unpickles these files. Only load trusted, locally produced data.
    ent2id = np.load(os.path.join(data_path, 'ent2id.npy'), allow_pickle=True).item()
    id2ent = {j: i for i, j in ent2id.items()}
    print(f'ent2id size:{len(ent2id)}')

    rel2id = np.load(os.path.join(data_path, 'rel2id.npy'), allow_pickle=True).item()
    id2rel = {j: i for i, j in rel2id.items()}

    ent2rule = np.load(os.path.join(data_path, 'ent2rule.npy'), allow_pickle=True).item()

    triples = np.load(os.path.join(data_path, 'train_triples.npy'), allow_pickle=True).item()

    embed_file = './embed.vec'
    # Use a context manager so the file handle is closed deterministically.
    with open(embed_file, 'r', encoding='utf-8') as f:
        embedings = json.load(f)
    # keys: ['zero_const', 'pi_const', 'ent_embeddings.weight', 'rel_embeddings.weight']

    # %%
    # initial_state = ['血管性水肿', '瘙痒', '反复发作', '消退迅速', 'IgE检测为阳性', '寒冷敏感检测为阴性', '恶心', '呕吐', '腹痛']

    # extra = ['胸闷','肿胀','腹泻','水肿']
    # extra = []
    # initial_state += extra

    initial_state = ["普通感冒", "咳嗽", "鼻流涕", "咳痰量较多", "气管炎", "肺炎", "呼吸道感染", "鼻炎"]  # pediatric bronchitis

    RULES, RULES_symptoms = get_RULES()
    print(len(RULES_symptoms))

    # Report rule symptoms that have no matching knowledge-graph entity.
    cnt = 0
    for sym in RULES_symptoms:
        if sym not in ent2id:
            cnt += 1
            print(sym)
    print(f'missing symptoms:{cnt}')
    # RULES = align_term(RULES, ent2id)  # TODO: align symptoms in RULES with knowledge-graph terms

    # %%
    print(f'initial_state:{initial_state}')
    infer(initial_state, embedings, ent2id, ent2rule, RULES, id2rel, id2ent)
