from tqdm import tqdm
import numpy as np
import json
import random

def get_rank_score(conclusions, goal, goals):
    """Rank the gold ``goal`` among the candidate ``goals`` by inferred score.

    Only keys of ``conclusions`` that also appear in ``goals`` are used; every
    candidate absent from ``conclusions`` is scored 0. Ties are resolved
    pessimistically: the gold goal is placed below every other candidate whose
    score is not strictly lower than its own.

    Args:
        conclusions: dict mapping candidate -> score.
        goal: the gold candidate.
        goals: candidate list; exactly four entries are required (asserted).

    Returns:
        (reciprocal_rank, hits_at_1, hits_at_2) with the hit flags as ints.
        If ``goal`` has no score in ``conclusions`` the result is
        (1 / len(goals), 0, 0).
    """
    assert len(goals) == 4
    scores = {cand: value for cand, value in conclusions.items() if cand in goals}

    if goal not in scores:
        # No score for the gold goal: fall back to 1 / len(goals) (presumably
        # the expected reciprocal rank of a random guess) and count no hits.
        return 1.0 / len(goals), 0, 0

    for cand in goals:
        scores.setdefault(cand, 0)

    gold_score = scores[goal]
    strictly_lower = sum(
        1 for cand, value in scores.items() if cand != goal and value < gold_score
    )
    # Worst-case rank: everything not strictly below the gold goal outranks it.
    rank = len(scores) - strictly_lower
    assert 1 <= rank <= len(scores)

    return 1.0 / rank, int(rank == 1), int(rank <= 2)

def _read_id_mapping(path):
    """Read a ``name<TAB>id`` file whose first line is skipped into (name2id, id2name)."""
    name2id, id2name = {}, {}
    with open(path, 'r', encoding='utf-8') as f:
        next(f, None)  # the first line is a header and is not parsed
        for line in f:
            if not line.strip():
                # Tolerate stray blank lines instead of failing on the unpack below.
                continue
            name, raw_idx = line.split('\t')
            name, idx = name.strip('\n'), int(raw_idx.strip('\n'))
            name2id[name] = idx
            id2name[idx] = name
    return name2id, id2name


def load_KG(KG_data_path, embed_version=None):
    """Load the KG id maps, triples and embeddings from ``KG_data_path``.

    ``KG_data_path`` is joined to file names by plain string concatenation,
    so it must end with a path separator. Files read:
    ``entity2id.txt`` and ``relation2id.txt`` (header line, then
    ``name<TAB>id`` rows), ``triples_all.txt`` (``sub<TAB>rel<TAB>obj`` rows),
    and ``embed.vec`` or ``embed_{embed_version}.vec`` (JSON).

    Args:
        KG_data_path: directory prefix ending with a separator.
        embed_version: optional suffix selecting ``embed_{embed_version}.vec``.

    Returns:
        (ent2id, id2ent, rel2id, id2rel, triples, embeddings) where ``triples``
        is a set of ``(sub, obj, rel)`` tuples (note the order) and
        ``embeddings`` is the parsed JSON object.
    """
    ent2id, id2ent = _read_id_mapping(KG_data_path + 'entity2id.txt')
    rel2id, id2rel = _read_id_mapping(KG_data_path + 'relation2id.txt')

    triples = set()
    with open(KG_data_path + 'triples_all.txt', 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            sub, rel, obj = (field.strip('\n') for field in line.split('\t'))
            # Stored as (sub, obj, rel); callers unpack in this order.
            triples.add((sub, obj, rel))

    embed_file = 'embed.vec' if embed_version is None else f'embed_{embed_version}.vec'
    # Use a context manager so the (potentially large) file is closed promptly.
    with open(KG_data_path + embed_file, 'r', encoding='utf-8') as f:
        embeddings = json.load(f)

    print(f'KG loaded with {len(triples)} triples, {len(ent2id)} entities,  {len(rel2id)} relations')
    return ent2id, id2ent, rel2id, id2rel, triples, embeddings

def do_noisy_or(prob_list):
    """Fold a sequence of signed scores into one value, left to right.

    For each running value ``a`` and next value ``b``:

    * both positive:  ``a + b - a*b``
    * both negative:  ``a + b + a*b``
    * otherwise:      ``(a + b) / (1 - min(|a|, |b|))``

    (This matches the MYCIN certainty-factor combination rule; inputs are
    expected to lie in [-1, 1] — not validated here.)

    Args:
        prob_list: non-empty indexable/sliceable sequence of numbers.

    Returns:
        The combined value. When the mixed-sign denominator is 0 (e.g. +1
        combined with -1) the result for that step is 0.0 instead of raising.

    Raises:
        IndexError: if ``prob_list`` is empty.
    """
    final_prob = prob_list[0]
    for prob in prob_list[1:]:
        if final_prob > 0 and prob > 0:
            final_prob = final_prob + prob - final_prob * prob
        elif final_prob < 0 and prob < 0:
            final_prob = final_prob + prob + final_prob * prob
        else:
            denominator = 1 - min(abs(final_prob), abs(prob))
            # Fully certain evidence for and against makes the rule 0/0;
            # treat it as total conflict rather than crashing.
            final_prob = (final_prob + prob) / denominator if denominator != 0 else 0.0
    return final_prob

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, threshold, noisy_or=False, show_infer_step=False, incre=False, relations=('并发症', '病因', '相关疾病')):
    """Collect embedding evidence that entity ``in_no`` is linked from known entities.

    For every entity ``knw`` in ``V_knw`` with a value > 0, and every relation
    whose name (via ``id2rel``) is in ``relations``, the score is
    ``cos(emb[knw] + emb_rel, emb[in_no]) * V_knw[knw]``. Scores
    ``>= threshold`` are kept as evidence.

    Args:
        in_no: index of the candidate tail entity in ``id2embed_ent``.
        V_knw: dict mapping entity index -> weight; entries <= 0 are ignored.
        id2embed_ent: entity embeddings indexed by entity id.
        id2embed_rel: relation embeddings; position i is relation id i.
        id2rel: relation id -> relation name.
        id2ent: entity id -> entity name (only used when printing steps).
        threshold: minimum score for a link to count as evidence.
        noisy_or: combine evidence with ``do_noisy_or`` instead of ``max``.
        show_infer_step: print each accepted ``head + relation -> tail``.
        incre: kept for compatibility; it has no effect. The previous
            incremental and non-incremental branches applied the same filter.
        relations: relation names allowed for linking.

    Returns:
        ``(True, combined_score)`` if any evidence passed the threshold,
        otherwise ``(False, 0)``.
    """
    allowed_relations = frozenset(relations)
    # The target vector and its norm are the same for every (knw, relation) pair.
    vec2 = id2embed_ent[in_no]
    norm2 = np.linalg.norm(vec2)

    in_no_prob = []
    positive = [(knw, weight) for knw, weight in V_knw.items() if weight > 0]
    for knw, weight in positive:
        vec_sub = id2embed_ent[knw]
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            if id2rel[idx_rel] not in allowed_relations:
                continue
            vec1 = np.add(vec_sub, vec_rel)
            cos_sim = float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * norm2)) * weight
            if cos_sim >= threshold:
                if show_infer_step:
                    print(f'{id2ent[knw]} + {id2rel[idx_rel]} -> {id2ent[in_no]}')
                in_no_prob.append(cos_sim)

    if not in_no_prob:
        return False, 0
    return True, (do_noisy_or(in_no_prob) if noisy_or else max(in_no_prob))


def load_data(data_path, split, IncreDise=None):
    """Load diagnosis samples plus the symptom and disease id maps.

    Sample files are JSON Lines: one JSON object per line. Blank lines are
    skipped. ``data_path`` is joined by string concatenation, so it must end
    with a path separator.

    Args:
        data_path: directory prefix ending with a separator.
        split: ``'train + valid'`` concatenates ``diagnose_train.json`` and
            ``diagnose_valid.json``. Any other value reads
            ``diagnose_{split}.json``.
        IncreDise: if not None, overrides ``split`` and reads
            ``IncreSetting/diagnose_incre_{IncreDise}.json``.

    Returns:
        (data, symptoms, disease): the list of parsed samples, and the JSON
        contents of ``id2symptom.json`` and ``id2disease.json``.
    """
    # Always build a list of files. A bare string here was previously iterated
    # character by character when IncreDise was combined with 'train + valid'.
    if IncreDise is not None:
        diag_files = [data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json']
    elif split == 'train + valid':
        diag_files = [data_path + 'diagnose_train.json', data_path + 'diagnose_valid.json']
    else:
        diag_files = [data_path + f'diagnose_{split}.json']

    data = []
    for diag_file in diag_files:
        with open(diag_file, 'r', encoding='utf-8') as f1:
            for line in f1:
                if line.strip():
                    data.append(json.loads(line))

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def infer(initial_state, embeddings, ent2id, id2rel, id2ent, noisy_or, threshold, diseases, show_infer_step=False):
    """Score every candidate disease from the observed symptom states.

    Args:
        initial_state: dict mapping entity name (looked up in ``ent2id``) to the
            string flag ``'True'`` or ``'False'``; these become weights 1.0 / -1.0.
        embeddings: dict holding ``'ent_embeddings.weight'`` and
            ``'rel_embeddings.weight'``.
        ent2id, id2rel, id2ent: KG id maps.
        noisy_or, threshold: forwarded to ``check_triple``.
        diseases: iterable of candidate disease names.
        show_infer_step: print inference steps and per-disease results.

    Returns:
        dict mapping each disease name to the score returned by
        ``check_triple`` (0 when no link passed the threshold).
    """
    polarity = {'True': 1.0, 'False': -1.0}

    known_states = {}
    for entity_name, flag in initial_state.items():
        known_states[ent2id[entity_name]] = flag
    known_prob = {ent_id: polarity[flag] for ent_id, flag in known_states.items()}

    scores = {}
    for disease in diseases:
        linked, score = check_triple(
            ent2id[disease],
            known_prob,
            embeddings['ent_embeddings.weight'],
            embeddings['rel_embeddings.weight'],
            id2rel,
            id2ent,
            threshold,
            noisy_or=noisy_or,
            show_infer_step=show_infer_step,
            incre=False,
        )
        scores[disease] = score

        if show_infer_step:
            print(f'Link Prediction:{linked}')
            print(f'Add Node {disease} with prob {score}.')

    return scores


def report_metrics(data, embeddings, ent2id, id2rel, id2ent, diseases, noisy_or, threshold):
    """Evaluate disease ranking over ``data``, then print and return the metrics.

    Each sample must provide ``'symptoms'`` (the initial state passed to
    ``infer``) and ``'disease'`` (the gold label ranked by ``get_rank_score``).

    Returns:
        (coverage, accuracy, accuracy_plus, f1_score, hits_1_score,
        hits_2_score, mrr)

    Note:
        ``na_cnt`` is never incremented, so coverage is always 1.0 and
        accuracy equals Hits@1. An empty ``data`` raises ZeroDivisionError.
    """
    na_cnt = 0
    hits_1_cnt = hits_2_cnt = 0
    reciprocal_ranks = []

    for sample in tqdm(data):
        symptom_state, gold = sample['symptoms'], sample['disease']
        ranked = infer(symptom_state, embeddings, ent2id, id2rel, id2ent, noisy_or, threshold, diseases, show_infer_step=False)
        rr, hit_1, hit_2 = get_rank_score(ranked, gold, diseases)
        reciprocal_ranks.append(rr)
        hits_1_cnt += hit_1
        hits_2_cnt += hit_2

    # A top-1 hit is what counts as a correct diagnosis.
    correct_cnt = hits_1_cnt
    n_total = len(data)
    n_answered = n_total - na_cnt

    coverage = 1 - na_cnt / n_total
    accuracy = correct_cnt / n_answered
    accuracy_plus = (correct_cnt + na_cnt * 0.25) / n_total
    f1_score = 2 * accuracy * coverage / (accuracy + coverage)
    hits_1_score = hits_1_cnt / n_total
    hits_2_score = hits_2_cnt / n_total
    mrr = np.mean(reciprocal_ranks)

    print(f'Coverage:{coverage}; ({n_answered} of {n_total} samples)')
    print(f'Accuracy:{accuracy}; ({correct_cnt} of {n_answered} samples)')
    print(f'Acc_plus: {accuracy_plus}')
    print(f'F1 Score: {f1_score}')
    print(f'Hits@1: {hits_1_score}')
    print(f'Hits@2: {hits_2_score}')
    print(f'MRR: {mrr}')

    return coverage, accuracy, accuracy_plus, f1_score, hits_1_score, hits_2_score, mrr

def _print_embedding_statistics(triples, embeddings, ent2id, rel2id):
    """Print min/mean/median/max cosine similarity of (h + r) vs t over the gold triples."""
    ent_vectors = embeddings['ent_embeddings.weight']
    rel_vectors = embeddings['rel_embeddings.weight']
    sims = []
    for (sub, obj, rel) in triples:
        vec1 = np.add(ent_vectors[ent2id[sub]], rel_vectors[rel2id[rel]])
        vec2 = ent_vectors[ent2id[obj]]
        sims.append(float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))))
    if not sims:
        # np.min / np.max raise on an empty list; there is nothing to report.
        print('No triples loaded; skipping cosine statistics.')
        return
    # Statistics: average cosine similarity between (h + r) and t over gold triples (h, r, t).
    print('统计gold triple (h, r, t)中 (h+r)与 t 的平均余弦相似度')
    print(f'min cosine dist:{np.min(sims)}')
    print(f'avg cosine dist:{np.mean(sims)}')
    print(f'median cosine dist:{np.median(sims)}')
    print(f'max cosine dist:{np.max(sims)}')


def _grid_search(KG_path, search_data, search_split, disease_names,
                 embed_versions=range(1000, 11000, 1000),
                 thresholds=(-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)):
    """Return (best_embedding, best_threshold, history), choosing the pair by highest Hits@1.

    ``history`` maps ``(version, threshold)`` to ``{'accuracy': ..., 'coverage': ...}``.
    """
    history = {}
    # Start below any achievable score so the first configuration is always
    # adopted. Starting at 0 left the best embedding at the non-existent
    # version 0 when every configuration scored 0.
    best_perform = float('-inf')
    best_embedding, best_threshold = None, None
    for ver in embed_versions:
        ent2id, id2ent, rel2id, id2rel, triples, embeddings = load_KG(KG_path, embed_version=ver)
        _print_embedding_statistics(triples, embeddings, ent2id, rel2id)
        for threshold in thresholds:
            print(f'\nPerformance on {search_split} Set with KG embeddings {ver}, threshold {threshold} ...')
            coverage, accuracy, _, _, hits_1_score, _, _ = report_metrics(search_data, embeddings, ent2id, id2rel, id2ent, disease_names, noisy_or=True, threshold=threshold)
            if hits_1_score > best_perform:
                best_perform = hits_1_score
                best_embedding, best_threshold = ver, threshold
            history[(ver, threshold)] = {'accuracy': accuracy, 'coverage': coverage}
    return best_embedding, best_threshold, history


def main(cogkg_path='/home/weizhepei/workspace/CogKG/', search_split='train + valid',
         performance_path='../PERFORMANCE_PURELINK.json', history_path=None):
    """Tune the KG embedding version and link threshold, then evaluate on the test split.

    Args:
        cogkg_path: project root. It must end with a path separator because
            sub-paths are appended by string concatenation.
        search_split: split used for the grid search. Use 'train + valid' for
            tuning; per the original notes, 'test' was used for the Acc-Cov curve.
        performance_path: output file for the final test metrics (JSON).
        history_path: if given, the per-(version, threshold) accuracy and
            coverage from the search are written there as JSON.
    """
    data_path = cogkg_path + 'data/diagnose/aligned/'
    KG_path = cogkg_path + 'data/KG/miniKG/'

    # Only the disease inventory is needed from the training split here.
    _, _, diseases = load_data(data_path, split='train')
    disease_names = list(diseases.values())

    # The search split does not depend on the embedding version; load it once.
    search_data, _, _ = load_data(data_path, split=search_split)

    best_embedding, best_threshold, history = _grid_search(KG_path, search_data, search_split, disease_names)

    if history_path is not None:
        with open(history_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps({str(k): v for k, v in history.items()}, ensure_ascii=False, indent=4))

    print(f'\nPerformance on Test Set with KG embeddings {best_embedding}, threshold {best_threshold}.')
    ent2id, id2ent, _, id2rel, _, embeddings = load_KG(KG_path, embed_version=best_embedding)
    test_data, _, _ = load_data(data_path, split='test')

    coverage, accuracy, accuracy_plus, f1_score, hits_1_score, hits_2_score, mrr = report_metrics(
        test_data, embeddings, ent2id, id2rel, id2ent, disease_names, noisy_or=True, threshold=best_threshold)

    performance = {'Coverage': coverage, 'Accuracy': accuracy, 'Accuracy_plus': accuracy_plus, 'F1_score': f1_score,
                   'Hits@1': hits_1_score, 'Hits@2': hits_2_score, 'MRR': mrr}
    with open(performance_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(performance, ensure_ascii=False, indent=4))
    


if __name__ == "__main__":
    # Run the grid search and test-set evaluation only when executed as a script.
    main()