import numpy as np
import re,os,csv
import json
from tqdm import tqdm

import argparse
# Command-line interface: --setting selects which evaluation the __main__ block runs.
parser = argparse.ArgumentParser(
    description='Evaluate the MYCIN disease-diagnosis rules on the CogKG data.')
parser.add_argument(
    '--setting', choices=['gene', 'incre'], required=True,
    help="evaluation to run: 'gene' = K-fold evaluation (main), "
         "'incre' = incremental held-out-disease evaluation with KG link prediction (main_incre)")

"""
Mycin: a medical expert system.

This is a small example of an expert system that uses the
[Emycin](../../emycin.html) shell.  It defines a few contexts, parameters, and
rules, and presents a rudimentary user interface to collect data about an
infection in order to determine the identity of the infecting organism.

In a more polished system, we could:

- define and use a domain-specific language for the expert;
- present a more polished interface, perhaps a GUI, for user interaction;
- offer a data serialization mechanism to save state between sessions.

This implementation comes from chapter 16 of Peter Norvig's "Paradigms of
Artificial Intelligence Programming".

"""

### Utility functions

def eq(x, y):
    """Return ``x == y``.

    Used as the comparison callable inside rule premise and conclusion
    tuples, e.g. ``(name, 'symptom', eq, True)``.
    """
    return x == y

def boolean(string):
    """Parse the exact text ``'True'`` or ``'False'`` into a bool.

    The match is case-sensitive with no whitespace tolerance; any other
    input raises ``ValueError``.
    """
    for text, value in (('True', True), ('False', False)):
        if string == text:
            return value
    raise ValueError('bool must be True or False')

def load_data(data_path, K_fold=None, IncreDise=None, change=None):
    """Load diagnosis samples plus the symptom and disease vocabularies.

    Exactly one samples file under ``data_path`` is read, chosen by the
    arguments (checked in this order):

    - ``K_fold`` given, ``change`` None: ``K-fold/diag_valid_fold_{K_fold}.json``
    - ``K_fold`` and ``change`` given:
      ``K-fold/{change}/diag_{change}_fold_{K_fold}.json``
    - ``IncreDise`` given: ``IncreSetting/diagnose_incre_{IncreDise}.json``
    - otherwise: ``diagnose.json``

    The samples file holds one JSON object per line; blank lines are skipped.

    Args:
        data_path: directory prefix. It must end with a path separator,
            because file names are appended by string concatenation.
        K_fold: fold index selecting a K-fold samples file.
        IncreDise: disease name selecting an incremental-setting samples file.
        change: optional sub-variant of the K-fold file (only used with K_fold).

    Returns:
        ``(data, symptoms, disease)``: the list of parsed samples, and the
        parsed contents of ``id2symptom.json`` and ``id2disease.json``.
    """
    if K_fold is not None:
        if change is None:
            diag_file = data_path + f'K-fold/diag_valid_fold_{K_fold}.json'
        else:
            diag_file = data_path + f'K-fold/{change}/diag_{change}_fold_{K_fold}.json'
    elif IncreDise is not None:
        diag_file = data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json'
    else:
        diag_file = data_path + 'diagnose.json'

    with open(diag_file, 'r', encoding='utf-8') as f1:
        # JSON Lines; skip blank lines (e.g. a stray trailing newline)
        # instead of crashing in json.loads.
        data = [json.loads(line) for line in f1 if line.strip()]

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def load_rules(sh, rule_path, K_fold=None, incre_dise=None):
    """Parse rule CSV files and register every rule with the shell ``sh``.

    Reads every file whose name ends in ``.csv`` directly inside
    ``rule_path``, or inside ``rule_path + f'K-fold/fold-{K_fold}'`` when
    ``K_fold`` is given. When ``incre_dise`` is given, any file whose *name*
    contains it is skipped, so that disease's rules can be held out. Files
    are processed in sorted name order so rule ids are reproducible.

    Each file has a header row. In every following row:

    - column 0 is a rule string. The first parenthesised, comma-separated
      list of (optionally single-quoted) names becomes the ``symptom``
      premises, and the text after ``THEN`` becomes the ``disease.identity``
      conclusion.
    - column 1 is converted to float and passed to ``Rule`` as its last
      argument (presumably the rule's certainty factor).

    Args:
        sh: the shell that receives the rules via ``define_rule``.
        rule_path: directory prefix. It must end with a path separator when
            ``K_fold`` is used, because the sub-path is concatenated.
        K_fold: optional fold index selecting a per-fold rule directory.
        incre_dise: optional disease name whose rule file(s) are skipped.
    """
    premise_pattern = re.compile(r'[(](.*?)[)]', re.S)
    conclusion_pattern = re.compile(r'(?<=THEN).*$')

    cnt = 0
    if K_fold is not None:
        rule_path += f'K-fold/fold-{K_fold}'

    for file in sorted(os.listdir(rule_path)):
        # Decide on the file name only, never on the directory part of the path.
        if not file.endswith('.csv'):
            continue
        if incre_dise is not None and incre_dise in file:
            continue
        csv_file = os.path.join(rule_path, file)
        with open(csv_file, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for item in reader:
                if not item:  # tolerate empty lines
                    continue
                symptoms = premise_pattern.findall(item[0])[0].split(',')
                symptoms = [s.strip().strip("'") for s in symptoms]
                disease = conclusion_pattern.findall(item[0])[0].strip()
                premise = [(s, 'symptom', eq, True) for s in symptoms]
                _rule = Rule(cnt, premise, [('identity', 'disease', eq, disease)], float(item[1]))
                sh.define_rule(_rule)
                cnt += 1

    print(f'{cnt} rules have been added in MYCIN.')

### Setting up initial data

# Here we define the contexts, parameters, and rules for our system.  This is
# the job of the expert, and in a more polished system, we would define and use
# a domain-specific language to make this easier.

def set_initial_states(sh, initial_states):
    """Hand ``initial_states`` to the shell's ``init_state`` method.

    NOTE(review): the evaluation loops in this file call
    ``sh.set_init_state(...)`` directly, and this helper is not called
    anywhere in the visible source. Confirm that ``init_state`` actually
    exists on the ``Shell`` in use.
    """
    initialize = sh.init_state
    initialize(initial_states)

def define_contexts(sh):
    """Register the ``symptom`` and ``disease`` contexts with the shell.

    The ``disease`` context declares ``identity`` as its goal parameter,
    i.e. the value the shell tries to determine.
    """
    context_specs = (
        ('symptom', {}),
        # The identity of the disease is what the consultation should find.
        ('disease', {'goals': ['identity']}),
    )
    for name, options in context_specs:
        sh.define_context(Context(name, **options))

def define_params(sh, symptoms, diseases):
    """Declare the disease goal parameter and one parameter per symptom.

    Args:
        sh: the shell being configured.
        symptoms: mapping whose values are symptom names. Each name becomes
            a ``symptom`` parameter whose values are parsed with ``boolean``.
        diseases: mapping whose values are disease names. They form the
            ``enum`` of allowed values for the ``disease.identity`` parameter.
    """
    disease_names = list(diseases.values())
    identity_param = Parameter('identity', 'disease', enum=disease_names, ask_first=False)
    sh.define_param(identity_param)

    for symptom_name in symptoms.values():
        sh.define_param(Parameter(symptom_name, 'symptom', cls=boolean, ask_first=False))
    
### Running the system

import logging
from emycin import Parameter, Context, Rule, Shell

def get_most_likely_disease(findings):
    """Return the disease with the highest positive certainty, or ``None``.

    Only the first entry of ``findings`` is inspected; its ``'identity'``
    mapping (disease -> certainty) is scanned. On ties the earliest disease
    wins. ``None`` is returned when ``findings`` is empty or no certainty is
    strictly greater than 0.
    """
    for result in findings.values():
        best_name, best_cf = None, 0
        for name, cf in result['identity'].items():
            if cf > best_cf:
                best_name, best_cf = name, cf
        return best_name
    return None

def get_MRR(findings, goal, goals):
    """Rank the true disease among all candidates and score that rank.

    Ranking is pessimistic (worst case): candidates tied with the goal count
    against it. The goal's rank is therefore the number of candidates that
    are not strictly below it. Candidates from ``goals`` missing from
    ``findings`` count as certainty 0. If the goal itself is missing from
    ``findings``, it is ranked last (``len(goals)``).

    Args:
        findings: mapping ``disease name -> certainty``. It is NOT modified.
        goal: the ground-truth disease name.
        goals: mapping whose values are all candidate disease names.

    Returns:
        ``(reciprocal_rank, hits_at_1, hits_at_2)``; the hits values are 0 or 1.

    Raises:
        ValueError: if ``goals`` is empty.
    """
    if not goals:
        raise ValueError('goals must contain at least one candidate')

    if goal not in findings:
        # No evidence at all for the true disease: rank it last.
        rank = len(goals)
    else:
        # Work on a copy so the caller's dict is left untouched.
        scores = dict(findings)
        for candidate in goals.values():
            scores.setdefault(candidate, 0)

        target = scores[goal]
        less_cnt = sum(1 for name, score in scores.items() if name != goal and score < target)
        rank = len(scores) - less_cnt  # worst case: ties rank ahead of the goal

    hits_1 = 1 if rank == 1 else 0
    hits_2 = 1 if rank <= 2 else 0
    return 1.0 / rank, hits_1, hits_2


def main(cogkg_path='/home/weizhepei/workspace/CogKG/', n_folds=10):
    """Run the K-fold ("gene") evaluation of the MYCIN rule base.

    For each fold, a fresh shell is configured with that fold's rules and
    every validation sample is diagnosed. Per-fold metrics are printed,
    followed by their averages over all folds.

    Args:
        cogkg_path: root directory of the CogKG project. It must end with a
            path separator, because sub-paths are appended by concatenation.
        n_folds: number of folds to evaluate (fold ids ``0 .. n_folds - 1``).
    """
    def ratio(num, den):
        # A metric with an empty denominator (no samples, or every sample
        # N/A) is undefined: report NaN instead of aborting the whole run.
        return num / den if den else float('nan')

    logging.basicConfig(level=logging.INFO)
    data_path = cogkg_path + 'data/diagnose/aligned/'
    rule_path = cogkg_path + 'data/rule/disease_rule/'

    acc_list = []
    hits_1_list, hits_2_list = [], []
    mrr_list = []
    valid_acc_list = []
    for i in range(n_folds):
        data, symptoms, diseases = load_data(data_path, K_fold=i, change=None)
        n_samples = len(data)
        print(f'\nTest Fold-{i} with {n_samples} Samples ...')
        # Fresh MYCIN shell holding only this fold's rules.
        sh = Shell()
        define_contexts(sh)
        define_params(sh, symptoms, diseases)
        load_rules(sh, rule_path, K_fold=i)

        # Hits@1 is "top-ranked disease is correct", so this counter is also
        # the Hits@1 count.
        correct_cnt = 0
        hits_2_cnt = 0
        na_cnt = 0
        MRR = []
        for item in tqdm(data):
            sh.set_init_state(item['symptoms'])
            findings = sh.execute(['symptom', 'disease'])

            # N/A: no disease received a positive certainty.
            if get_most_likely_disease(findings) is None:
                na_cnt += 1

            # Certainties of the first context instance only, matching
            # get_most_likely_disease.
            first_result = next(iter(findings.values()), None)
            goals_dict = dict(first_result['identity']) if first_result is not None else {}

            mrr, hits_1, hits_2 = get_MRR(goals_dict, item["disease"], diseases)
            MRR.append(mrr)
            correct_cnt += hits_1
            hits_2_cnt += hits_2

        n_valid = n_samples - na_cnt
        acc = ratio(correct_cnt, n_samples)
        valid_acc = ratio(correct_cnt, n_valid)
        fold_mrr = np.mean(MRR) if MRR else float('nan')
        hits_1_rate = acc
        hits_2_rate = ratio(hits_2_cnt, n_samples)

        print(f'{correct_cnt} of {n_samples} samples can be correctly diagnosed by MYCIN. Global Accuracy:{acc}')
        print(f'{na_cnt} of {n_samples} samples cannot be diagnosed (no applicable rule) by MYCIN. N/A Rate:{ratio(na_cnt, n_samples)}')
        print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by MYCIN. Local Accuracy:{valid_acc}')
        print(f'Global MRR: {fold_mrr}')
        print(f'Global Hits@1: {hits_1_rate}')
        print(f'Global Hits@2: {hits_2_rate}')

        acc_list.append(acc)
        hits_1_list.append(hits_1_rate)
        hits_2_list.append(hits_2_rate)
        mrr_list.append(fold_mrr)
        valid_acc_list.append(valid_acc)

    # nanmean ignores folds whose metric was undefined (NaN) instead of
    # turning the whole average into NaN; without NaNs it equals np.mean.
    avg_Acc = np.nanmean(acc_list)
    avg_Valid_Acc = np.nanmean(valid_acc_list)
    avg_MRR = np.nanmean(mrr_list)
    avg_Hits_1 = np.nanmean(hits_1_list)
    avg_Hits_2 = np.nanmean(hits_2_list)
    print(f"\n{n_folds}-Fold avg Global Acc:{avg_Acc}")
    print(f"{n_folds}-Fold avg Local Acc:{avg_Valid_Acc}")
    print(f"{n_folds}-Fold avg Global MRR:{avg_MRR}")
    print(f"{n_folds}-Fold avg Global Hits@1:{avg_Hits_1}")
    print(f"{n_folds}-Fold avg Global Hits@2:{avg_Hits_2}")


def main_incre(cogkg_path='/home/weizhepei/workspace/CogKG/'):
    """Run the incremental evaluation, holding out one disease at a time.

    For every disease listed in ``id2disease.json``:

    - that disease's rule file is not loaded;
    - its incremental samples are diagnosed by MYCIN;
    - KG link prediction (``apply_link_prediction``) may add a certainty for
      the held-out disease before ranking.

    Per-disease metrics and their averages are printed.

    Args:
        cogkg_path: root directory of the CogKG project. It must end with a
            path separator, because sub-paths are appended by concatenation.
    """
    def ratio(num, den):
        # Undefined metric (empty denominator) -> NaN rather than a crash.
        return num / den if den else float('nan')

    data_path = cogkg_path + 'data/diagnose/aligned/'
    rule_path = cogkg_path + 'data/rule/disease_rule/'
    KG_path = cogkg_path + "data/KG/miniKG/"
    ent2id, id2ent, _, id2rel, _, embeddings = load_KG(KG_path)

    _, _, all_diseases = load_data(data_path)

    acc_list = []
    hits_1_list, hits_2_list = [], []
    mrr_list = []
    valid_acc_list = []

    for held_out in all_diseases.values():
        data, symptoms, diseases = load_data(data_path, IncreDise=held_out)
        n_samples = len(data)

        print(f'\nTest Incremental Setting with {n_samples} Samples of {held_out}...')

        # Fresh MYCIN shell without the held-out disease's rule file(s).
        sh = Shell()
        define_contexts(sh)
        define_params(sh, symptoms, diseases)
        load_rules(sh, rule_path, incre_dise=held_out)

        correct_cnt = 0  # also the Hits@1 count
        hits_2_cnt = 0
        na_cnt = 0
        MRR = []
        for item in tqdm(data):
            sh.set_init_state(item['symptoms'])
            findings = sh.execute(['symptom', 'disease'])

            # N/A: no disease received a positive certainty from the rules.
            if get_most_likely_disease(findings) is None:
                na_cnt += 1

            # Certainties of the first context instance only.
            first_result = next(iter(findings.values()), None)
            goals_dict = dict(first_result['identity']) if first_result is not None else {}

            # Infer a certainty for the held-out disease from the KG. The
            # target is the held-out disease, NOT item["disease"], so the
            # sample's ground-truth label is never fed into the inference.
            goals_dict = apply_link_prediction(held_out, goals_dict, ent2id, id2rel, id2ent, embeddings, do_sum=True, show_infer_step=False)

            mrr, hits_1, hits_2 = get_MRR(goals_dict, item["disease"], diseases)
            MRR.append(mrr)
            correct_cnt += hits_1
            hits_2_cnt += hits_2

        # NOTE(review): na_cnt counts samples with no rule-based prediction,
        # but link prediction can still make them correct, so the "valid"
        # accuracy below can exceed 1.
        n_valid = n_samples - na_cnt
        acc = ratio(correct_cnt, n_samples)
        valid_acc = ratio(correct_cnt, n_valid)
        dise_mrr = np.mean(MRR) if MRR else float('nan')
        hits_1_rate = acc
        hits_2_rate = ratio(hits_2_cnt, n_samples)

        print(f'{correct_cnt} of {n_samples} samples can be correctly diagnosed by MYCIN. Accuracy:{acc}')
        print(f'{na_cnt} of {n_samples} samples cannot be diagnosed (no applicable rule) by MYCIN. N/A Rate:{ratio(na_cnt, n_samples)}')
        print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by MYCIN. Accuracy:{valid_acc}')
        print(f'Global MRR: {dise_mrr}')
        print(f'Global Hits@1: {hits_1_rate}')
        print(f'Global Hits@2: {hits_2_rate}')

        acc_list.append(acc)
        hits_1_list.append(hits_1_rate)
        hits_2_list.append(hits_2_rate)
        mrr_list.append(dise_mrr)
        valid_acc_list.append(valid_acc)

    # nanmean skips diseases whose metric was undefined (NaN).
    avg_Acc = np.nanmean(acc_list)
    avg_Valid_Acc = np.nanmean(valid_acc_list)
    avg_MRR = np.nanmean(mrr_list)
    avg_Hits_1 = np.nanmean(hits_1_list)
    avg_Hits_2 = np.nanmean(hits_2_list)
    print(f"\nIncremental avg Acc:{avg_Acc}")
    print(f"Incremental avg Valid Acc:{avg_Valid_Acc}")
    print(f"Incremental avg MRR:{avg_MRR}")
    print(f"Incremental avg Global Hits@1:{avg_Hits_1}")
    print(f"Incremental avg Global Hits@2:{avg_Hits_2}")

def load_KG(KG_data_path):
    """Load a knowledge graph's id maps, triples and embeddings.

    Files read directly under ``KG_data_path``, which must end with a path
    separator:

    - ``entity2id.txt`` / ``relation2id.txt``: the first line is skipped
      (presumably an entry count); the remaining lines are ``name<TAB>id``.
    - ``triples_all.txt``: ``subject<TAB>relation<TAB>object`` lines.
    - ``embed.vec``: a JSON document. ``apply_link_prediction`` reads its
      ``'ent_embeddings.weight'`` and ``'rel_embeddings.weight'`` entries.

    Blank lines are ignored, and ``\\r``/``\\n`` line endings are stripped
    from names.

    Returns:
        ``(ent2id, id2ent, rel2id, id2rel, triples, embeddings)``.
        ``triples`` is a set of ``(subject, object, relation)`` tuples; note
        the object-before-relation order.
    """
    def read_id_map(path):
        # name -> int id and int id -> name, skipping the first (header) line.
        name2id, id2name = {}, {}
        with open(path, 'r', encoding='utf-8') as f:
            next(f, None)
            for line in f:
                if not line.strip():
                    continue
                name, idx = line.split('\t')
                name, idx = name.strip('\r\n'), int(idx)
                name2id[name] = idx
                id2name[idx] = name
        return name2id, id2name

    ent2id, id2ent = read_id_map(KG_data_path + 'entity2id.txt')
    rel2id, id2rel = read_id_map(KG_data_path + 'relation2id.txt')

    triples = set()
    with open(KG_data_path + 'triples_all.txt', 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            sub, rel, obj = (part.strip('\r\n') for part in line.split('\t'))
            triples.add((sub, obj, rel))

    # Use a context manager so the handle is closed; JSON is read as UTF-8.
    with open(KG_data_path + 'embed.vec', 'r', encoding='utf-8') as f:
        embedings = json.load(f)

    print(f'KG laoded with {len(triples)} triples, {len(ent2id)} entities,  {len(rel2id)} relations')
    return ent2id, id2ent, rel2id, id2rel, triples, embedings

def apply_link_prediction(incre_setting, known_nodes, ent2id, id2rel, id2ent, embedings, do_sum, show_infer_step=False):
    """Possibly add ``incre_setting`` to ``known_nodes`` using KG link prediction.

    The known certainties are re-keyed by entity id and passed to
    ``check_triple`` with ``threshold=0`` and ``incre=True``. If that reports
    support, the returned score is stored under ``incre_setting``.

    Args:
        incre_setting: entity name of the node to infer.
        known_nodes: mapping ``entity name -> certainty``; updated in place.
        ent2id, id2rel, id2ent: id maps as returned by ``load_KG``.
        embedings: parsed embedding JSON with ``'ent_embeddings.weight'`` and
            ``'rel_embeddings.weight'`` entries.
        do_sum: combine scores with ``sum_prob`` (True) or take their max.
        show_infer_step: print the inference steps and the result.

    Returns:
        ``known_nodes`` itself (the same dict object).
    """
    ent_vectors = embedings['ent_embeddings.weight']
    rel_vectors = embedings['rel_embeddings.weight']
    target_id = ent2id[incre_setting]

    known_by_id = {}
    for name, certainty in known_nodes.items():
        known_by_id[ent2id[name]] = certainty

    supported, score = check_triple(
        target_id,
        known_by_id,
        ent_vectors,
        rel_vectors,
        id2rel,
        id2ent,
        do_sum=do_sum,
        show_infer_step=show_infer_step,
        threshold=0,
        incre=True,
    )
    if not supported:
        return known_nodes

    known_nodes[incre_setting] = score
    if show_infer_step:
        print(f'Link Prediction Applied! Add Node {incre_setting} with prob {score}.')
    return known_nodes

def sum_prob(prob_list):
    """Combine probabilities with the probabilistic (noisy) OR.

    Pairs are folded left to right with ``a + b - a * b``. For independent
    events this equals ``1 - prod(1 - p)``, the probability that at least
    one event occurs. It is an inclusive OR, not XOR.

    Args:
        prob_list: any iterable of numbers (typically in [0, 1]).

    Returns:
        The combined value; ``0.0`` for an empty input, which is the identity
        of this operation.
    """
    probs = list(prob_list)
    if not probs:
        return 0.0
    final_prob = probs[0]
    for p in probs[1:]:
        final_prob = final_prob + p - final_prob * p
    return final_prob

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, do_sum=False, show_infer_step=False, threshold=0.5, incre=False):
    """Score how well the known entities support entity ``in_no`` via relations.

    For every known entity ``knw`` with positive certainty, and every
    relation ``r`` (only the three relations below when ``incre`` is true),
    compute the cosine similarity between ``embed(knw) + embed(r)`` and
    ``embed(in_no)``, then scale it by ``knw``'s certainty. This is a
    head + relation ≈ tail check, presumably on TransE-style embeddings
    (confirm). Scores ``>= threshold`` are collected.

    Args:
        in_no: entity id of the candidate tail.
        V_knw: mapping ``entity id -> certainty`` of known entities.
        id2embed_ent, id2embed_rel: embedding vectors indexed by entity and
            relation id.
        id2rel, id2ent: id -> name maps (for filtering and printing).
        do_sum: combine passing scores with ``sum_prob`` instead of ``max``.
        show_infer_step: print every passing inference step.
        threshold: minimum scaled similarity for a step to count.
        incre: restrict to the incremental-setting relations.

    Returns:
        ``(True, score)`` if any step passed the threshold, otherwise
        ``(False, 0)``.
    """
    # Relations used in the incremental setting: complication (并发症),
    # etiology / cause of disease (病因), related disease (相关疾病).
    incre_relations = {'并发症', '病因', '相关疾病'}

    # Loop-invariant: the tail embedding and its norm.
    vec_obj = id2embed_ent[in_no]
    norm_obj = np.linalg.norm(vec_obj)

    in_no_prob = []
    for knw, certainty in V_knw.items():
        if not certainty > 0:
            continue
        vec_sub = id2embed_ent[knw]
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            if incre and id2rel[idx_rel] not in incre_relations:
                continue
            vec1 = np.add(vec_sub, vec_rel)
            cos_sim = float(np.dot(vec1, vec_obj) / (np.linalg.norm(vec1) * norm_obj)) * certainty
            if cos_sim >= threshold:
                if show_infer_step:
                    print(f'{id2ent[knw]} + {id2rel[idx_rel]} -> {id2ent[in_no]}')
                in_no_prob.append(cos_sim)

    if not in_no_prob:
        return False, 0
    return True, (sum_prob(in_no_prob) if do_sum else max(in_no_prob))

if __name__ == '__main__':
    args = parser.parse_args()
    # Map each --setting value to the evaluation it runs.
    runners = {'gene': main, 'incre': main_incre}
    runner = runners.get(args.setting)
    if runner is None:
        # SystemExit(msg) prints msg to stderr and exits with status 1.
        # Unlike the site-injected exit() helper, it is always available
        # (e.g. under `python -S`).
        raise SystemExit('Specify the Setting! (--setting gene|incre)')
    runner()