import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from utils import *
from tqdm import tqdm
import random

import argparse
# Command-line interface: ``--setting`` picks the evaluation mode,
# 'gene' (general setting, see main) or 'incre' (incremental setting, see main_incre).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--setting',
    choices=('gene', 'incre'),
)

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, do_sum, threshold, show_infer_step=False, incre=False):
    """Link-predict entity ``in_no`` from the currently known entities.

    For every known entity ``knw`` with a positive certainty ``V_knw[knw]`` and
    every allowed relation ``r``, the translated vector ``E[knw] + R[r]`` is
    compared with ``E[in_no]`` by cosine similarity. The similarity is scaled by
    the certainty of ``knw``, and every scaled score ``>= threshold`` is kept.

    Args:
        in_no: id of the entity to predict (index into ``id2embed_ent`` / ``id2ent``).
        V_knw: mapping entity id -> certainty; only positive entries are used.
        id2embed_ent: entity vectors indexed by entity id.
        id2embed_rel: relation vectors indexed by relation id.
        id2rel: relation names indexed by relation id.
        id2ent: entity names indexed by entity id (only used for logging).
        do_sum: if True, combine the kept scores with ``sum_prob``
            (a + b - a * b); otherwise take their maximum.
        threshold: minimum scaled cosine similarity for a link to be accepted.
        show_infer_step: print every accepted link.
        incre: use the relation names of the incremental-setting KG instead of
            the general-setting relations.

    Returns:
        ``(True, score)`` if at least one link passed the threshold,
        otherwise ``(False, 0)``.
    """
    # Choose exactly one allowed relation set. A chained ``if incre ... / elif``
    # test would run the general-setting check in incremental mode too and
    # therefore reject every relation.
    if incre:
        allowed_rels = {'并发症', '病因', '相关疾病'}
    else:
        allowed_rels = {'focus_of', 'associated_with', 'temporally_related_to'}

    # The target vector and its norm do not depend on the loop variables.
    vec_obj = id2embed_ent[in_no]
    norm_obj = np.linalg.norm(vec_obj)

    in_no_prob = []
    for knw, certainty in V_knw.items():
        if not certainty > 0:
            continue
        vec_sub = id2embed_ent[knw]
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            if id2rel[idx_rel] not in allowed_rels:
                continue
            vec1 = np.add(vec_sub, vec_rel)
            denom = np.linalg.norm(vec1) * norm_obj
            if denom == 0:
                # Cosine similarity is undefined for a zero vector (it would be NaN,
                # which can never pass the threshold), so skip it without a warning.
                continue
            cos_sim = float(np.dot(vec1, vec_obj) / denom) * certainty
            if cos_sim >= threshold:
                if show_infer_step:
                    print(f'{id2ent[knw]} + {id2rel[idx_rel]} -> {id2ent[in_no]}')
                in_no_prob.append(cos_sim)

    if not in_no_prob:
        return False, 0
    # do_sum: a + b - a * b accumulated by sum_prob; otherwise max(prob_list).
    return True, (sum_prob(in_no_prob) if do_sum else max(in_no_prob))

def get_InferGraph(initial_state, ent2id, ent2rule, rule_dict, show_fig=False):
    '''Build the inference subgraph reachable from the initial state.

    Pseudocode:
        Rule dict: {rule name: ([input states], [output states])}

        Take the initial state.
        Create the inference subgraph G and add the initial nodes to it.
        Set the processed rule-node set V_mem = {}.
        Set the current rule-node set V_cur from the rules linked to the initial nodes.
        While V_cur - V_mem is not empty:
            For each rule in V_cur - V_mem, add its input/output nodes and edges
            to G (needs the rule dict).
            Add V_cur to V_mem.
            Recompute V_cur: the rule nodes connected to the entities currently
            in G (needs the entity-name -> rule-names map).

    Graph layout: entity nodes are keyed by entity id, and rule nodes by rule id.
    Every node has ``desc`` set to its own key. Edges run premise -> rule and
    rule -> conclusion, each with ``name`` set to the rule id.

    Returns:
        (InferGraph, V_knw, V_mem): the graph; a dict mapping entity id -> value
        taken from ``initial_state``; the set of every rule id added to the graph.
    '''
    graph = nx.DiGraph()
    known = {}
    frontier = set()
    seen = set()

    def _redraw():
        # Optional visual trace of every construction step.
        if show_fig:
            draw_graph(graph)

    # Seed the graph with the initially known entities and collect their rules.
    for ent_name, value in initial_state.items():
        ent_id = ent2id[ent_name]
        known[ent_id] = value
        graph.add_node(ent_id, desc=ent_id)
        _redraw()
        if ent_name in ent2rule:
            for rule_name in ent2rule[ent_name]:
                frontier.add(rule_name)
    _redraw()

    # Expand rule by rule until no unseen rule is connected to the graph.
    pending = frontier - seen
    while len(pending) > 0:
        for rule_name in pending:
            graph.add_node(rule_name, desc=rule_name)
            _redraw()
            premises = rule_dict[rule_name][0]
            conclusions = rule_dict[rule_name][1]
            for premise in premises:
                src = ent2id[premise]
                if src not in graph:
                    graph.add_node(src, desc=src)
                graph.add_edge(src, rule_name, name=rule_name)
                _redraw()
            for conclusion in conclusions:
                dst = ent2id[conclusion]
                if dst not in graph:
                    graph.add_node(dst, desc=dst)
                graph.add_edge(rule_name, dst, name=rule_name)
                _redraw()
        seen = frontier | seen
        frontier = FindRuleNodes(graph, ent2rule)
        pending = frontier - seen

    return graph, known, seen

def update_InferGraph_prob(G, V_knw_prob, V_mem, id2embed_ent, id2embed_rel,id2rel, id2ent, ent2id, rule_dict, do_sum, threshold, incre_setting=None, show_infer_step=False):
    """Propagate certainty values over the inference graph until no node is added.

    Each pass visits the rule nodes of ``V_mem`` in ascending numeric order of
    their id. A rule fires at most once, and only when both of these hold:
      * none of its premises had the value -1 at the start of the pass, and
      * every premise either already has a value or can be link-predicted with
        ``check_triple``.
    Its conclusion receives ``min(premise values) * rule_dict[rule][2]``. If the
    conclusion already has a value other than +/-1.0, the two values are merged
    with the (E)MYCIN certainty-factor combination formulas. The loop stops
    after a pass that adds no new key to the state.

    Args:
        G: graph with edges premise -> rule -> conclusion; only ``in_edges`` and
            ``out_edges`` of rule nodes are used.
        V_knw_prob: initial mapping entity id -> certainty (``infer`` builds it
            with +1.0 / -1.0). It is not modified.
        V_mem: rule node ids to visit. It is not modified; a sorted copy is used.
        id2embed_ent, id2embed_rel, id2rel: forwarded to ``check_triple``.
        id2ent: entity names by id, used for logging and the incremental filter.
        ent2id: entity name -> id, used to resolve rule premises.
        rule_dict: rule id -> (premise names, conclusion names, multiplier). The
            multiplier is presumably the rule confidence; confirm.
        do_sum, threshold: forwarded to ``check_triple`` for premise prediction.
        incre_setting: optional disease name. If given, a final link prediction
            scores that disease from the four hard-coded diseases in the state.
        show_infer_step: print every inference step.

    Returns:
        (Inferred_nodes, V_cur): certainties of the nodes added or updated by
        inference, and the complete final state.
    """
    Inferred_nodes = {}
    V_cur = V_knw_prob.copy()

    # Visit rules in ascending numeric order of the id suffix. This assumes ids
    # are one prefix character followed by an integer (e.g. 'R12'); confirm.
    # Sort a copy, because sorting in place would reorder the caller's list.
    V_mem = sorted(V_mem, key=lambda x: int(x[1:]))

    cnt = 1
    visited_rule = set()
    while True:
        if show_infer_step:
            print(f'--------Inferring Iteration {cnt}--------')
        # Snapshot of the state at the start of this pass.
        V_knw_prob = V_cur.copy()
        for node in V_mem:
            in_nodes = [i[0] for i in G.in_edges(node)]
            # Only the first conclusion edge of the rule is used.
            out_node = [i[1] for i in G.out_edges(node)][0]
            early_break = False
            for i in in_nodes:
                if i in V_knw_prob and V_knw_prob[i] == -1:
                    early_break = True
                    break
            if early_break:
                # A premise was known to be false at the start of this pass.
                continue
            elif len(set(in_nodes) - set(V_cur.keys())) == 0: # all premises are satisfied
                rule_flag = True
            else: # do link prediction for premise
                rule_flag = True
                for in_no in set(in_nodes) - set(V_cur.keys()):
                    isTrue, prob = check_triple(in_no, V_cur, id2embed_ent, id2embed_rel, id2rel, id2ent, do_sum=do_sum, threshold=threshold, show_infer_step=show_infer_step)
                    if isTrue:
                        V_cur[in_no] = prob
                        Inferred_nodes[in_no] = prob
                        if show_infer_step:
                            print(f'Link Prediction Applied! Add Node {in_no}:{id2ent[in_no]} with prob {prob}.')
                    else:
                        rule_flag = False
                        break

            if rule_flag and node not in visited_rule:
                visited_rule.add(node)
                # The weakest premise bounds the rule's output, scaled by the rule multiplier.
                rule_prob = min([V_cur[ent2id[i]] for i in rule_dict[node][0]]) * rule_dict[node][2]

                if show_infer_step:
                    print(f'Rule {node}:{rule_dict[node]} Applied! Add Node {out_node}:{id2ent[out_node]} with prob {rule_prob}.')
                if out_node not in V_cur:
                    V_cur[out_node] = rule_prob
                    Inferred_nodes[out_node] = rule_prob
                elif V_cur[out_node] not in [1.0, -1.0]:
                    # Values of exactly +/-1.0 are certain and never revised.
                    if rule_prob > 0 and V_cur[out_node] > 0:
                        V_cur[out_node] =  rule_prob + V_cur[out_node] - rule_prob * V_cur[out_node]
                    elif rule_prob < 0 and V_cur[out_node] < 0:
                        V_cur[out_node] =  rule_prob + V_cur[out_node] + rule_prob * V_cur[out_node]
                    else:
                        V_cur[out_node] = (rule_prob + V_cur[out_node]) / (1 - min(abs(rule_prob), abs(V_cur[out_node])))
                    assert V_cur[out_node] <= 1 and V_cur[out_node] >= -1
                    Inferred_nodes[out_node] = V_cur[out_node]

        cnt += 1
        # Fixpoint: this pass added no new node to the state.
        if len(set(V_cur.keys()) - set(V_knw_prob.keys())) == 0:
            break

    if incre_setting is not None:
        incre_node = ent2id[incre_setting]

        # Only these four diseases in the state serve as link-prediction sources.
        V_knw_dises = {k:v for k,v in V_cur.items() if id2ent[k] in ["小儿腹泻", "小儿支气管炎", "小儿感冒", "小儿消化不良"]}

        isTrue, prob = check_triple(incre_node, V_knw_dises, id2embed_ent, id2embed_rel, id2rel, id2ent, do_sum=True, show_infer_step=show_infer_step, threshold=0, incre=True)

        if isTrue:
            V_cur[incre_node] = prob
            Inferred_nodes[incre_node] = prob
            if show_infer_step:
                print(f'Link Prediction Applied! Add Node {incre_node}:{id2ent[incre_node]} with prob {prob}.')

    return Inferred_nodes, V_cur


def infer(initial_state, embeddings, ent2id, ent2rule, rule_dict, id2rel, id2ent, do_sum, threshold, show_fig=False, show_infer_step=False, incre_setting=None, show_statistics=False, rel2id=None, triples=None):
    '''Run probabilistic rule inference on the cognitive graph for one case.

    Steps:
      1. Take the entity and relation embeddings from ``embeddings.solver``
         (the original notes describe them as TransE embeddings of the KG).
      2. Build the inference subgraph reachable from ``initial_state``.
      3. Propagate certainty values through that subgraph.

    ``initial_state`` maps entity name -> the string 'True' or 'False', which is
    converted to +1.0 / -1.0. ``show_statistics``, ``rel2id`` and ``triples`` are
    accepted but unused.

    Returns:
        dict mapping entity name -> certainty for every node that inference
        added or updated.
    '''
    ent_vectors, rel_vectors = embeddings.solver.values()

    sub_graph, known, rule_nodes = get_InferGraph(initial_state, ent2id=ent2id, ent2rule=ent2rule, rule_dict=rule_dict, show_fig=show_fig)
    rule_nodes = sorted(rule_nodes)

    truth_value = {'True': 1.0, 'False': -1.0}
    certainties = {node: truth_value[flag] for node, flag in known.items()}

    inferred, _ = update_InferGraph_prob(
        sub_graph, certainties, rule_nodes,
        ent_vectors, rel_vectors, id2rel, id2ent, ent2id, rule_dict,
        do_sum, threshold,
        incre_setting=incre_setting, show_infer_step=show_infer_step,
    )
    return {id2ent[node]: score for node, score in inferred.items()}


def k_fold_cross_val(data_path, rule_path, embeddings, ent2id, id2rel, id2ent, threshold=0.0):
    """Evaluate CogInfer with 10-fold cross-validation in the general setting.

    For each fold ``i``, the fold's data and rules are loaded and every sample is
    diagnosed with ``infer``. Per-fold metrics are printed. Samples whose gold
    disease has hits@1 are written to ``tmp.json``; the file is rewritten every
    fold, so it ends up holding only the last fold's hits. Averages over the
    folds are printed at the end.

    Args:
        data_path, rule_path: dataset and rule directories for ``load_data`` /
            ``load_rules``.
        embeddings, ent2id, id2rel, id2ent: KG artefacts forwarded to ``infer``.
        threshold: link-prediction threshold forwarded to ``infer``. ``infer``
            requires it, and omitting it raised TypeError. The default 0.0 is a
            placeholder; tune it the way ``main`` does.
    """
    acc_list = []
    hits_1_list, hits_2_list = [], []
    mrr_list = []
    valid_acc_list = []
    for i in range(10):
        data, _, diseases = load_data(data_path, K_fold=i, change=None)
        print(f'\nTest General Setting Fold-{i} with {len(data)} Samples ...')

        rule_dict = load_rules(rule_path,  K_fold=i)
        ent2rule = get_ent2rule(rule_dict)

        correct_cnt = 0
        hits_2_cnt = 0
        na_cnt = 0
        MRR = []
        # Records are dumped with ensure_ascii=False, so write UTF-8 explicitly.
        # The context manager also closes the file (it used to leak every fold).
        with open('tmp.json', 'w', encoding='utf-8') as f:
            for item in tqdm(data):
                conclusions = infer(item['symptoms'], embeddings, ent2id, ent2rule, rule_dict, id2rel, id2ent, do_sum=False, threshold=threshold)

                pred_dise, _ = get_most_likely_disease(conclusions, diseases)
                if pred_dise is None:
                    na_cnt += 1

                mrr, hits_1, hits_2 = get_MRR(conclusions, item['disease'], diseases)
                if hits_1 == 1:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')

                MRR.append(mrr)
                correct_cnt += hits_1
                hits_2_cnt += hits_2

        n_total = len(data)
        n_valid = n_total - na_cnt
        acc = correct_cnt / n_total
        # Accuracy and Hits@1 both count samples with hits_1 set.
        hits_1_score = correct_cnt / n_total
        hits_2_score = hits_2_cnt / n_total
        # Avoid ZeroDivisionError when no sample could be diagnosed.
        valid_acc = correct_cnt / n_valid if n_valid else 0.0
        mean_mrr = np.mean(MRR)

        print(f'{correct_cnt} of {n_total} samples can be correctly diagnosed by CogInfer. Accuracy:{acc}')
        print(f'{na_cnt} of {n_total} samples cannot be diagnosed (no applicable rule) by CogInfer. N/A Rate:{na_cnt/n_total}')
        print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by CogInfer. Accuracy:{valid_acc}')
        print(f'MRR: {mean_mrr}')
        print(f'Global Hits@1: {hits_1_score}')
        print(f'Global Hits@2: {hits_2_score}')

        acc_list.append(acc)
        hits_1_list.append(hits_1_score)
        hits_2_list.append(hits_2_score)
        mrr_list.append(mean_mrr)
        valid_acc_list.append(valid_acc)

    print(f"\n10-Fold avg Acc:{np.mean(acc_list)}")
    print(f"10-Fold avg Valid Acc:{np.mean(valid_acc_list)}")
    print(f"10-Fold avg MRR:{np.mean(mrr_list)}")
    print(f"10-Fold avg Global Hits@1:{np.mean(hits_1_list)}")
    print(f"10-Fold avg Global Hits@2:{np.mean(hits_2_list)}")


def report_metrcis(data, embeddings, threshold, ent2id, ent2rule, rule_dict, id2rel, id2ent, diseases, noisy_or):
    """Diagnose every sample in ``data`` with ``infer`` and print/return metrics.

    Args:
        data: non-empty sized iterable of samples, each with 'symptoms' (the
            initial state) and 'disease' (gold label).
        embeddings, ent2id, ent2rule, rule_dict, id2rel, id2ent: forwarded to ``infer``.
        threshold: link-prediction threshold forwarded to ``infer``.
        diseases: candidate diseases used for ranking.
        noisy_or: forwarded to ``infer`` as ``do_sum``.

    Returns:
        (global accuracy, local accuracy, coverage, Hits@1, Hits@2, mean MRR, F1).
        Local accuracy is computed over samples with a prediction. F1 is the
        harmonic mean of local accuracy and coverage. Both are 0.0 when nothing
        could be diagnosed.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError('report_metrcis: data must contain at least one sample')

    correct_cnt = 0
    hits_2_cnt = 0
    na_cnt = 0
    MRR = []

    for item in tqdm(data):
        conclusions = infer(item['symptoms'], embeddings, ent2id, ent2rule, rule_dict, id2rel, id2ent, do_sum=noisy_or, threshold=threshold, show_infer_step=False)

        pred_dise, _ = get_most_likely_disease(conclusions, diseases)
        if pred_dise is None:
            na_cnt += 1

        mrr, hits_1, hits_2 = get_MRR(conclusions, item['disease'], diseases)
        MRR.append(mrr)
        # Accuracy and Hits@1 both count samples with hits_1 set.
        correct_cnt += hits_1
        hits_2_cnt += hits_2

    n_total = len(data)
    n_valid = n_total - na_cnt
    coverage = 1 - na_cnt/n_total
    acc_global = correct_cnt/n_total
    # Guard against ZeroDivisionError when no sample could be diagnosed.
    acc_local = correct_cnt/n_valid if n_valid else 0.0
    hits_1_score = correct_cnt/n_total
    hits_2_score = hits_2_cnt/n_total
    f1 = 2 * acc_local * coverage / (acc_local + coverage) if (acc_local + coverage) > 0 else 0.0
    mean_mrr = np.mean(MRR)

    print(f'{correct_cnt} of {n_total} samples can be correctly diagnosed by CogInfer. Accuracy:{acc_global}')
    print(f'{na_cnt} of {n_total} samples cannot be diagnosed (no applicable rule) by CogInfer. Coverage:{coverage}')
    print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by CogInfer. Local Accuracy:{acc_local}')
    print(f'MRR: {mean_mrr}')
    print(f'Acc_plus: {(correct_cnt + 0.5 * na_cnt)/n_total}')
    print(f'Global Hits@1: {hits_1_score}')
    print(f'Global Hits@2: {hits_2_score}')
    print(f'F1 Score:{f1}')

    return acc_global, acc_local, coverage, hits_1_score, hits_2_score, mean_mrr, f1


def main():
    """General setting: tune the threshold on train+valid, then evaluate on test.

    Each candidate threshold is scored by Global Hits@1 on the 'train + valid'
    split. Only a strict improvement replaces the best threshold, so ties keep
    the earlier, smaller candidate. The best threshold is then applied to the
    'test' split, and the resulting metrics are written to
    ``../PERFORMANCE_CogInfer.json``.
    """
    CogKG_path = '/home/weizhepei/workspace/CogKG_english/'
    data_path = CogKG_path + 'data/diagnose/aligned/'
    rule_path = CogKG_path + 'data/rule/disease_rule/'
    KG_path = CogKG_path + "data/KG/"

    rule_dict = load_rules(rule_path)
    ent2rule = get_ent2rule(rule_dict)

    # The training split's disease list is the candidate set for all evaluations.
    _, _, diseases = load_data(data_path, split='train')

    # threshold -> {'accuracy', 'coverage'}; kept for a coverage/accuracy curve.
    HISTORY = {}
    best_score = 0
    BEST_THRESHOLD = 0

    # The KG and embeddings are only read during inference, so load them once
    # and reuse them for both tuning and testing.
    ent2id, id2ent, _, id2rel, embeddings = load_KG(KG_path, embed_version=None)

    valid_data, _, _ = load_data(data_path, split='train + valid')

    print(f'\nPerformance on Valid Set with {len(valid_data)} Samples ...')
    for threshold in [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        print(f'\nPerformance on Train + Validation Set with threshold {threshold} ...')
        _, acc_local, coverage, hits_1, _, _, _ = report_metrcis(valid_data, embeddings, threshold, ent2id, ent2rule, rule_dict, id2rel, id2ent, diseases, noisy_or=False)

        HISTORY[threshold] = {'accuracy': acc_local, 'coverage': coverage}

        if hits_1 > best_score:
            best_score = hits_1
            BEST_THRESHOLD = threshold

    print(f'BEST_THRESHOLD: {BEST_THRESHOLD}')
    test_data, _, _ = load_data(data_path, split='test')
    _, acc_local, coverage, hits_1, hits_2, mrr, _ = report_metrcis(test_data, embeddings, BEST_THRESHOLD, ent2id, ent2rule, rule_dict, id2rel, id2ent, diseases, noisy_or=False)

    PERFORMANCE = {'Accuracy':acc_local, 'Coverage':coverage, 'Hits@1':hits_1, 'Hits@2':hits_2, 'MRR':mrr}

    with open('../PERFORMANCE_CogInfer.json', 'w', encoding='utf-8') as f:
        f.write(json.dumps(PERFORMANCE, ensure_ascii=False, indent=4))



def main_incre(threshold=0.0):
    """Incremental setting: evaluate each disease as a newly added disease.

    For every disease, its samples and the rules for that incremental setup are
    loaded. Each sample is then diagnosed with ``infer(..., incre_setting=dise)``.
    Per-disease metrics are printed, followed by their averages.

    Args:
        threshold: link-prediction threshold forwarded to ``infer``. ``infer``
            requires it, and omitting it raised TypeError. The default 0.0
            matches the threshold ``update_InferGraph_prob`` uses for the
            incremental-disease prediction; tune if needed.
    """
    CogKG_path = '/home/weizhepei/workspace/CogKG/'
    data_path = CogKG_path + 'data/diagnose/aligned/'
    rule_path = CogKG_path + 'data/rule/disease_rule/'
    KG_path = CogKG_path + "data/KG/miniKG/"
    # NOTE(review): main() unpacks five values from load_KG but this call unpacks
    # six; confirm which return signature the current utils.load_KG has.
    ent2id, id2ent, rel2id, id2rel, triples, embeddings = load_KG(KG_path, embed_version=3000)

    _, _, diseases = load_data(data_path)

    acc_list = []
    hits_1_list, hits_2_list = [], []
    mrr_list = []
    valid_acc_list = []

    for dise in diseases.values():

        data, _, _ = load_data(data_path, IncreDise=dise)

        print(f'\nTest Incremental Setting with {len(data)} Samples of {dise}...')

        rule_dict = load_rules(rule_path, incre_dise=dise)
        ent2rule = get_ent2rule(rule_dict)

        correct_cnt = 0
        hits_2_cnt = 0
        na_cnt = 0
        MRR = []
        for item in tqdm(data):
            initial_state = item['symptoms']
            goal = item['disease']
            assert goal == dise

            conclusions = infer(initial_state, embeddings, ent2id, ent2rule, rule_dict, id2rel, id2ent, do_sum=False, threshold=threshold, incre_setting=dise)

            pred_dise, _ = get_most_likely_disease(conclusions, diseases)
            if pred_dise is None:
                na_cnt += 1

            mrr, hits_1, hits_2 = get_MRR(conclusions, goal, diseases)

            MRR.append(mrr)
            # Accuracy and Hits@1 both count samples with hits_1 set.
            correct_cnt += hits_1
            hits_2_cnt += hits_2

        n_total = len(data)
        n_valid = n_total - na_cnt
        acc = correct_cnt / n_total
        hits_1_score = correct_cnt / n_total
        hits_2_score = hits_2_cnt / n_total
        # Avoid ZeroDivisionError when no sample of this disease could be diagnosed.
        valid_acc = correct_cnt / n_valid if n_valid else 0.0
        mean_mrr = np.mean(MRR)

        print(f'{correct_cnt} of {n_total} samples can be correctly diagnosed by CogInfer. Accuracy:{acc}')
        print(f'{na_cnt} of {n_total} samples cannot be diagnosed (no applicable rule) by CogInfer. N/A Rate:{na_cnt/n_total}')
        print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by CogInfer. Accuracy:{valid_acc}')
        print(f'Global MRR: {mean_mrr}')
        print(f'Global Hits@1: {hits_1_score}')
        print(f'Global Hits@2: {hits_2_score}')

        acc_list.append(acc)
        hits_1_list.append(hits_1_score)
        hits_2_list.append(hits_2_score)
        mrr_list.append(mean_mrr)
        valid_acc_list.append(valid_acc)

    print(f"\nIncremental avg Acc:{np.mean(acc_list)}")
    print(f"Incremental avg Valid Acc:{np.mean(valid_acc_list)}")
    print(f"Incremental avg MRR:{np.mean(mrr_list)}")
    print(f"Incremental avg Global Hits@1:{np.mean(hits_1_list)}")
    print(f"Incremental avg Global Hits@2:{np.mean(hits_2_list)}")

if __name__ == '__main__':
    # Dispatch on --setting: 'gene' runs the general evaluation, 'incre' the incremental one.
    args = parser.parse_args()
    if args.setting == 'gene':
        main()
    elif args.setting == 'incre':
        main_incre()
    else:
        # Raise SystemExit directly. The ``exit`` builtin is injected by the
        # ``site`` module and is not guaranteed to exist (e.g. under ``python -S``).
        raise SystemExit('Specify the Setting! Use --setting gene or --setting incre.')
