import re, os, csv, json
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pickle

def set_color(ReferGraph, V_knw, V_mem):
    """Return one colour name per node of *ReferGraph*, in iteration order.

    A node found in ``V_knw`` is 'green'; otherwise one found in ``V_mem`` is
    'yellow'; every other node is 'lightblue'. ``V_knw`` takes precedence when
    a node is in both collections.

    Args:
        ReferGraph: iterable of nodes (iterating a networkx graph yields its nodes).
        V_knw: collection supporting ``in``; nodes to paint green.
        V_mem: collection supporting ``in``; nodes to paint yellow.

    Returns:
        list[str] aligned with the iteration order of ``ReferGraph``.
    """
    def _colour_of(vertex):
        if vertex in V_knw:
            return 'green'
        if vertex in V_mem:
            return 'yellow'
        return 'lightblue'

    return [_colour_of(vertex) for vertex in ReferGraph]

def draw_graph(G, color_map=None):
    """Draw *G* with a Graphviz ``dot`` layout and display it via ``plt.show()``.

    Node labels come from the ``'desc'`` node attribute and edge labels from
    the ``'name'`` edge attribute.

    Args:
        G: networkx graph to draw.
        color_map: optional per-node colours passed to ``nx.draw`` as
            ``node_color`` (e.g. the list returned by ``set_color``).

    Note:
        ``nx.drawing.nx_agraph.graphviz_layout`` requires pygraphviz.
    """
    # Only the dot layout is used; computing any other layout first would be
    # discarded work.
    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog='dot')

    nx.draw(G, pos, node_color=color_map)
    node_labels = nx.get_node_attributes(G, 'desc')
    nx.draw_networkx_labels(G, pos, labels=node_labels)
    edge_labels = nx.get_edge_attributes(G, 'name')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.show()

def _read_name2id(path):
    """Parse a ``name<TAB>id`` file whose first line is a header; return (name2id, id2name)."""
    name2id, id2name = {}, {}
    with open(path, 'r', encoding='utf-8') as f:
        next(f, None)  # first line is a header, not a mapping row
        for line in f:
            if not line.strip():
                # Tolerate blank (e.g. trailing) lines instead of failing the unpack.
                continue
            name, idx = line.split('\t')
            name, idx = name.strip('\n'), int(idx.strip('\n'))
            name2id[name] = idx
            id2name[idx] = name
    return name2id, id2name


def load_KG(KG_data_path, embed_version=None):
    """Load the KG entity/relation id maps and the pickled TransE embeddings.

    Reads ``entity2id.txt`` and ``relation2id.txt`` (one header line, then
    ``name<TAB>id`` rows; blank lines are skipped) and ``TransE.pkl``.

    Args:
        KG_data_path: string prefix the file names are appended to, so it
            normally ends with a path separator.
        embed_version: unused; kept for backward compatibility.

    Returns:
        ``(ent2id, id2ent, rel2id, id2rel, embedings)`` where the first four
        are dicts between names and integer ids and ``embedings`` is whatever
        object ``TransE.pkl`` unpickles to.

    Raises:
        ValueError: if a non-blank id row does not have exactly two
            tab-separated fields or its id is not an integer.
    """
    ent2id, id2ent = _read_name2id(KG_data_path + 'entity2id.txt')
    rel2id, id2rel = _read_name2id(KG_data_path + 'relation2id.txt')

    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    # The `with` block closes the handle, which the bare open() previously leaked.
    with open(KG_data_path + 'TransE.pkl', 'rb') as f:
        embedings = pickle.load(f)

    print(f'KG loaded with {len(ent2id)} entities,  {len(rel2id)} relations')
    return ent2id, id2ent, rel2id, id2rel, embedings

def get_ent2rule(RULES):
    """Build an inverted index from premise entity to the ids of rules using it.

    Args:
        RULES: dict mapping rule id -> sequence whose element 0 is the
            iterable of premise entities (e.g. the dict returned by
            ``load_rules``).

    Returns:
        dict mapping each entity to the set of rule ids that list it.
    """
    index = {}
    for rule_id, rule in RULES.items():
        premises = rule[0]
        for entity in premises:
            index.setdefault(entity, set()).add(rule_id)
    return index

def load_rules(rule_path,  K_fold=None, incre_dise=None):
    """Parse every ``*.csv`` rule file in a directory into a rule dictionary.

    Each CSV has a header row followed by rows whose first column looks like
    ``IF ('s1', 's2', ...) THEN disease`` and whose second column is a number
    (presumably the rule's certainty factor).

    Args:
        rule_path: directory string; ``K-fold/fold-{K_fold}`` is appended to it
            verbatim when ``K_fold`` is given, so it should then end with a
            path separator.
        K_fold: optional fold index selecting the ``K-fold/fold-{K_fold}``
            sub-directory.
        incre_dise: optional string; CSV files whose *file name* contains it
            are skipped.

    Returns:
        dict ``{'r0': (symptoms, [disease], weight), 'r1': ...}``. Files are
        read in sorted name order so rule ids are reproducible across runs.

    Raises:
        ValueError: if a rule row lacks the ``(...)`` symptom list or the
            ``THEN`` clause.
    """
    # Same patterns as before, compiled once per call instead of per row.
    symptom_re = re.compile(r'[(](.*?)[)]', re.S)
    disease_re = re.compile(r'(?<=THEN).*$')

    if K_fold is not None:
        rule_path += f'K-fold/fold-{K_fold}'

    rule_dict = {}
    cnt = 0
    # sorted(): os.listdir order is arbitrary, which made the r{cnt} ids unstable.
    for file in sorted(os.listdir(rule_path)):
        # Check the file name, not the full path: '.csv' must be the extension
        # (not e.g. 'x.csv.bak'), and incre_dise must not match the directory.
        if not file.endswith('.csv'):
            continue
        if incre_dise is not None and incre_dise in file:
            continue
        csv_file = os.path.join(rule_path, file)
        # newline='' is what the csv module expects; utf-8 matches the other loaders.
        with open(csv_file, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # header row
            for item in reader:
                if not item:
                    continue  # blank line in the CSV
                symptom_match = symptom_re.search(item[0])
                disease_match = disease_re.search(item[0])
                if symptom_match is None or disease_match is None:
                    raise ValueError(
                        f'Malformed rule in {csv_file} (line {reader.line_num}): {item[0]!r}'
                    )
                symtoms = [s.strip().strip("'") for s in symptom_match.group(1).split(',')]
                disease = disease_match.group(0).strip()
                rule_dict[f'r{cnt}'] = (symtoms, [disease], float(item[1]))
                cnt += 1
    print(f'{len(rule_dict)} rules are added!')
    return rule_dict

def load_data_kfold(data_path, K_fold=None, IncreDise=None, change=None):
    """Load diagnosis records for the K-fold / incremental settings plus id maps.

    This function used to be named ``load_data``. A second ``load_data`` defined
    later in this module replaced it at import time, so it could never be
    called. It now has its own name.

    File selection (first match wins):
      * ``K_fold`` and ``change`` given: ``K-fold/{change}/diag_{change}_fold_{K_fold}.json``
      * ``K_fold`` given: ``K-fold/diag_valid_fold_{K_fold}.json``
      * ``IncreDise`` given: ``IncreSetting/diagnose_incre_{IncreDise}.json``
      * otherwise: ``diagnose.json``

    Args:
        data_path: string prefix the file names are appended to, so it
            normally ends with a path separator.

    Returns:
        ``(data, symptoms, disease)``. ``data`` is the list of JSON objects, one
        per non-blank line of the diagnosis file. ``symptoms`` and ``disease``
        are the parsed ``id2symptom.json`` / ``id2disease.json``.
    """
    if K_fold is not None:
        if change is None:
            diag_file = data_path + f'K-fold/diag_valid_fold_{K_fold}.json'
        else:
            diag_file = data_path + f'K-fold/{change}/diag_{change}_fold_{K_fold}.json'
    elif IncreDise is not None:
        diag_file = data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json'
    else:
        diag_file = data_path + 'diagnose.json'

    with open(diag_file, 'r', encoding='utf-8') as f1:
        # Skip blank lines (e.g. a trailing empty line) that json.loads rejects.
        data = [json.loads(line) for line in f1 if line.strip()]

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def load_data_old(data_path, split, IncreDise=None):
    """Legacy loader: read one diagnosis JSON-lines file plus the id maps.

    Reads ``IncreSetting/diagnose_incre_{IncreDise}.json`` when ``IncreDise``
    is given, otherwise ``diagnose_{split}.json``.

    Args:
        data_path: string prefix the file names are appended to, so it
            normally ends with a path separator.
        split: split name substituted into ``diagnose_{split}.json``.
        IncreDise: optional name selecting the incremental-setting file.

    Returns:
        ``(data, symptoms, disease)``. ``data`` is the list of JSON objects, one
        per non-blank line. ``symptoms`` and ``disease`` are the parsed
        ``id2symptom.json`` / ``id2disease.json``.
    """
    if IncreDise is not None:
        diag_file = data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json'
    else:
        diag_file = data_path + f'diagnose_{split}.json'

    with open(diag_file, 'r', encoding='utf-8') as f1:
        # Blank lines (e.g. a trailing empty line) would make json.loads raise.
        data = [json.loads(line) for line in f1 if line.strip()]

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def load_data(data_path, split=None, IncreDise=None, K_fold=None, change=None):
    """Load diagnosis records (JSON lines) plus the symptom / disease id maps.

    File selection (first match wins):
      * ``K_fold`` and ``change`` given: ``K-fold/{change}/diag_{change}_fold_{K_fold}.json``
      * ``K_fold`` given: ``K-fold/diag_valid_fold_{K_fold}.json``
      * ``IncreDise`` given: ``IncreSetting/diagnose_incre_{IncreDise}.json``
      * ``split == 'train + valid'``: ``diagnose_train.json`` then ``diagnose_valid.json``
      * ``split is None``: ``diagnose.json``
      * otherwise: ``diagnose_{split}.json``

    Args:
        data_path: string prefix the file names are appended to, so it
            normally ends with a path separator.
        split: split name, or ``'train + valid'`` to concatenate both files.
        IncreDise: optional name selecting the incremental-setting file.
        K_fold: optional fold index (keyword; selects the K-fold files).
        change: optional K-fold variant directory, used only with ``K_fold``.

    Returns:
        ``(data, symptoms, disease)``. ``data`` is the list of JSON objects, one
        per non-blank line, in file order. ``symptoms`` and ``disease`` are the
        parsed ``id2symptom.json`` / ``id2disease.json``.
    """
    # Always build a list of paths. Previously a single path string could reach
    # the 'train + valid' loop (when IncreDise was also set) and be iterated
    # character by character.
    if K_fold is not None:
        if change is None:
            diag_files = [data_path + f'K-fold/diag_valid_fold_{K_fold}.json']
        else:
            diag_files = [data_path + f'K-fold/{change}/diag_{change}_fold_{K_fold}.json']
    elif IncreDise is not None:
        diag_files = [data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json']
    elif split == 'train + valid':
        diag_files = [data_path + 'diagnose_train.json', data_path + 'diagnose_valid.json']
    elif split is None:
        diag_files = [data_path + 'diagnose.json']
    else:
        diag_files = [data_path + f'diagnose_{split}.json']

    data = []
    for diag_file in diag_files:
        with open(diag_file, 'r', encoding='utf-8') as f1:
            # Skip blank lines (e.g. a trailing empty line) that json.loads rejects.
            data.extend(json.loads(line) for line in f1 if line.strip())

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def FindRuleNodes(G, ent2rule):
    """Return the set of rule ids attached to any node of *G*.

    Args:
        G: graph-like object exposing ``nodes()``.
        ent2rule: dict mapping entity -> iterable of rule ids (e.g. the
            result of ``get_ent2rule``).

    Returns:
        set of every rule id reachable from a node of ``G`` that is a key of
        ``ent2rule``.
    """
    return {
        rule_id
        for vertex in G.nodes()
        if vertex in ent2rule
        for rule_id in ent2rule[vertex]
    }

def sum_prob(prob_list):
    """Fold a sequence of certainty factors with the MYCIN combination rule.

    Pairwise, with running value ``x`` and next value ``y``:
      * both > 0:  ``x + y - x*y``
      * both < 0:  ``x + y + x*y``
      * otherwise: ``(x + y) / (1 - min(|x|, |y|))``

    Args:
        prob_list: indexable, sliceable sequence of numbers (list, tuple, or
            1-D array).

    Returns:
        The combined value. For a single element it is that element unchanged.
        An empty sequence returns ``0.0``, the identity of the combination
        (combining 0 with ``y`` yields ``y``).

    Note:
        Total conflict (a +1 combined with a -1) makes the mixed-sign formula
        0/0. It is treated as ``0.0`` (evidence cancels) instead of raising
        ZeroDivisionError.
    """
    # len() rather than truthiness so NumPy arrays are accepted too.
    if len(prob_list) == 0:
        return 0.0
    final_prob = prob_list[0]
    for cf in prob_list[1:]:
        if final_prob > 0 and cf > 0:
            final_prob = final_prob + cf - final_prob * cf
        elif final_prob < 0 and cf < 0:
            final_prob = final_prob + cf + final_prob * cf
        else:
            denom = 1 - min(abs(final_prob), abs(cf))
            final_prob = 0.0 if denom == 0 else (final_prob + cf) / denom
    return final_prob

def get_most_likely_disease(conclusions, diseases):
    """Pick the known disease with the highest strictly positive certainty.

    Only keys of ``conclusions`` that appear among the *values* of
    ``diseases`` are candidates. On a tie the candidate seen first in
    ``conclusions`` iteration order is kept.

    Args:
        conclusions: dict mapping conclusion -> certainty.
        diseases: dict whose values are the disease labels to consider.

    Returns:
        ``(disease, certainty)``, or ``(None, 0)`` when no candidate has a
        certainty above 0.
    """
    known = diseases.values()
    best = (None, 0)
    for label, score in conclusions.items():
        if label in known and score > best[1]:
            best = (label, score)
    return best

def get_MRR(conclusions, goal, goals, expected_num_goals=12):
    """Score the rank of the gold label *goal* among the candidate labels.

    Args:
        conclusions: dict mapping conclusion -> certainty. Keys that are not
            values of ``goals`` are ignored.
        goal: the correct candidate label.
        goals: dict whose values are the candidate labels.
        expected_num_goals: required ``len(goals)``. The default of 12 keeps
            the original check; pass ``None`` to accept any size.

    Returns:
        ``(reciprocal_rank, hits_1, hits_2)`` with ``hits_k`` in ``{0, 1}``.
        Candidates missing from ``conclusions`` get certainty 0. The rank is
        pessimistic: candidates tied with ``goal`` are ranked ahead of it. If
        ``goal`` has no certainty among the candidates, the result is
        ``(1 / len(goals), 0, 0)``.

    Raises:
        ValueError: if ``expected_num_goals`` is not None and differs from
            ``len(goals)``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if expected_num_goals is not None and len(goals) != expected_num_goals:
        raise ValueError(f'expected {expected_num_goals} goals, got {len(goals)}')

    candidates = goals.values()
    goals_dict = {item: certainty for item, certainty in conclusions.items() if item in candidates}
    if goal not in goals_dict:
        return 1.0 / len(goals), 0, 0

    for candidate in candidates:
        goals_dict.setdefault(candidate, 0)

    target = goals_dict[goal]
    less_cnt = sum(1 for k, v in goals_dict.items() if k != goal and v < target)
    rank = len(goals_dict) - less_cnt  # worst case: ties count against goal
    assert 1 <= rank <= len(goals_dict)  # internal invariant

    return 1.0 / rank, int(rank == 1), int(rank <= 2)
