import numpy as np
import re,os,csv
import json
from tqdm import tqdm

import argparse
# Command-line parser declaring the evaluation setting ('gene' or 'incre').
# NOTE(review): parse_args() is never called in this module, so the
# --setting flag currently has no effect on main().
parser = argparse.ArgumentParser()
parser.add_argument('--setting', choices=['gene', 'incre'])

"""
Mycin: a medical expert system.

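This is a small example of an expert system that uses the
[Emycin](../../emycin.html) shell.  It defines a few contexts, parameters, and
rules, and presents a rudimentary user interface to collect data about an
infection in order to determine the identity of the infecting organism.
This adaptation applies the same shell to disease diagnosis: its contexts are
symptoms and diseases, its rules are loaded from CSV files, and each sample's
observed symptoms are supplied to the shell as its initial state.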

In a more polished system, we could:

- define and use a domain-specific language for the expert;
- present a more polished interface, perhaps a GUI, for user interaction;
- offer a data serialization mechanism to save state between sessions.

This implementation comes from chapter 16 of Peter Norvig's "Paradigms of
Artificial Intelligence Programming."

"""

### Utility functions

def eq(x, y):
    """Return the result of comparing ``x`` and ``y`` with ``==``.

    Used as the comparison operator inside rule premises and conclusions.
    """
    is_equal = x == y
    return is_equal

def boolean(string):
    """
    Parse the exact strings 'True' and 'False' into the matching bool.

    Any other value (including different capitalisation) raises ValueError.
    """
    for text, value in (('True', True), ('False', False)):
        if string == text:
            return value
    raise ValueError('bool must be True or False')

def load_data(data_path, split, IncreDise=None):
    """Load diagnosis samples together with the symptom and disease vocabularies.

    Args:
        data_path: Directory prefix of the data files.  Paths are built by
            plain string concatenation, so it must end with a path separator.
        split: Split name (e.g. 'train', 'valid', 'test'); the samples are
            read from ``diagnose_{split}.json``.  Ignored when ``IncreDise``
            is given.
        IncreDise: Optional incremental-setting identifier; when set, the
            samples are read from ``IncreSetting/diagnose_incre_{IncreDise}.json``.

    Returns:
        ``(data, symptoms, disease)`` where ``data`` is a list holding one
        parsed JSON value per non-blank line of the diagnosis file, and
        ``symptoms`` / ``disease`` are the parsed contents of
        ``id2symptom.json`` and ``id2disease.json``.
    """
    if IncreDise is not None:
        diag_file = data_path + f'IncreSetting/diagnose_incre_{IncreDise}.json'
    else:
        diag_file = data_path + f'diagnose_{split}.json'

    # The diagnosis file is JSON Lines.  Blank lines (e.g. a trailing empty
    # line) are skipped because json.loads would reject them.
    with open(diag_file, 'r', encoding='utf-8') as f1:
        data = [json.loads(line) for line in f1 if line.strip()]

    with open(data_path + 'id2symptom.json', 'r', encoding='utf-8') as f2:
        symptoms = json.load(f2)
    with open(data_path + 'id2disease.json', 'r', encoding='utf-8') as f3:
        disease = json.load(f3)
    return data, symptoms, disease

def load_rules(sh, rule_path, K_fold=None, incre_dise=None):
    """Parse rule CSV files and register every rule with the MYCIN shell.

    Each CSV file has a header row.  In every following row, column 0 holds
    the rule text and column 1 its certainty factor.  The first parenthesised
    group of the rule text is a comma-separated list of symptom names
    (optionally single-quoted).  The text after ``THEN`` is the disease name.
    Each symptom becomes a premise ``(symptom, 'symptom', eq, True)``, and the
    conclusion is ``('identity', 'disease', eq, disease)``.

    Args:
        sh: Shell that receives the rules via ``sh.define_rule``.
        rule_path: Directory containing the rule CSV files.
        K_fold: Optional fold index.  When given, rules are read from
            ``rule_path + 'K-fold/fold-{K_fold}'`` instead.
        incre_dise: Optional string.  CSV files whose file name contains it
            are skipped.

    Side effects:
        Prints the number of rules loaded.
    """
    symptom_pattern = re.compile(r'[(](.*?)[)]', re.S)
    then_pattern = re.compile(r'(?<=THEN).*$')

    if K_fold is not None:
        rule_path += f'K-fold/fold-{K_fold}'

    cnt = 0
    # Sort the listing so rule ids (cnt) are assigned deterministically;
    # os.listdir returns entries in arbitrary order.
    for file_name in sorted(os.listdir(rule_path)):
        # Test the file name only, so a '.csv' or disease id that happens to
        # occur in the directory path cannot cause a false match.
        if not file_name.endswith('.csv'):
            continue
        if incre_dise is not None and incre_dise in file_name:
            continue
        csv_file = os.path.join(rule_path, file_name)
        with open(csv_file, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for item in reader:
                if not item:  # tolerate blank lines in the CSV
                    continue
                symptoms = symptom_pattern.findall(item[0])[0].split(',')
                symptoms = [s.strip().strip("'") for s in symptoms]
                disease = then_pattern.findall(item[0])[0].strip()
                premise = [(s, 'symptom', eq, True) for s in symptoms]
                conclusion = [('identity', 'disease', eq, disease)]
                sh.define_rule(Rule(cnt, premise, conclusion, float(item[1])))
                cnt += 1

    print(f'{cnt} rules have been added in MYCIN.')

### Setting up initial data

# Here we define the contexts, parameters, and rules for our system.  This is
# the job of the expert, and in a more polished system, we would define and use
# a domain-specific language to make this easier.

def set_initial_states(sh, initial_states):
    """Seed the shell with the initial symptom states of one sample.

    Delegates to ``sh.set_init_state``, the method the evaluation loops in
    this module use for the same purpose.
    """
    sh.set_init_state(initial_states)

def define_contexts(sh):
    """Register the 'symptom' and 'disease' contexts with the shell.

    The 'disease' context carries the single goal 'identity', which is the
    value the shell is asked to determine.
    """
    context_specs = (
        ('symptom', {}),
        # Finding the identity of the disease is our goal.
        ('disease', {'goals': ['identity']}),
    )
    for context_name, options in context_specs:
        sh.define_context(Context(context_name, **options))

def define_params(sh, symptoms, diseases):
    """Register the diagnosis goal parameter and one parameter per symptom.

    Args:
        sh: Shell that receives the parameters via ``define_param``.
        symptoms: Mapping whose values are symptom names.  Each becomes a
            parameter of the 'symptom' context with ``cls=boolean``.
        diseases: Mapping whose values are disease names.  They form the
            enumeration of the 'identity' parameter of the 'disease' context.
    """
    goal_param = Parameter('identity', 'disease',
                           enum=list(diseases.values()), ask_first=False)
    sh.define_param(goal_param)

    for symptom_name in list(symptoms.values()):
        symptom_param = Parameter(symptom_name, 'symptom',
                                  cls=boolean, ask_first=False)
        sh.define_param(symptom_param)
    
### Running the system

import logging
from emycin import Parameter, Context, Rule, Shell, CF

def get_most_likely_disease(findings):
    """Return the disease with the highest positive certainty, or None.

    Only the first entry of ``findings`` is inspected.  Its ``'identity'``
    mapping (disease -> certainty) is scanned.  On ties, the first disease to
    reach the highest certainty wins.  Returns None when ``findings`` is empty
    or no certainty is strictly greater than 0.
    """
    for result in findings.values():
        best_disease, best_certainty = None, 0
        for disease, certainty in result['identity'].items():
            if certainty > best_certainty:
                best_disease, best_certainty = disease, certainty
        return best_disease
    return None

def get_MRR(findings, goal, goals):
    """Compute the reciprocal rank and Hits@1/Hits@2 of the ground-truth disease.

    Args:
        findings: Mapping disease -> certainty for the diseases the shell
            concluded.  It is not modified.
        goal: Ground-truth disease name.
        goals: Non-empty mapping whose values are all candidate disease names
            (e.g. the id2disease vocabulary).  Any number of candidates is
            supported.

    Returns:
        ``(reciprocal_rank, hits_1, hits_2)`` with ``hits_1``/``hits_2`` in {0, 1}.
        If ``goal`` is absent from ``findings``, returns ``(1/len(goals), 0, 0)``.
        Otherwise candidates missing from ``findings`` get certainty 0, and
        ties are ranked pessimistically: ``goal`` is placed below every other
        candidate with the same certainty.

    Raises:
        ValueError: if ``goals`` is empty.
    """
    if not goals:
        raise ValueError('goals must contain at least one candidate disease')

    if goal not in findings:
        return 1.0 / len(goals), 0, 0

    # Work on a copy so the caller's findings dict is left untouched.
    scores = dict(findings)
    for name in goals.values():
        scores.setdefault(name, 0)

    target = scores[goal]
    less_cnt = sum(1 for k, v in scores.items() if k != goal and v < target)
    rank = len(scores) - less_cnt  # worst case among ties; always >= 1

    hits_1 = int(rank == 1)
    hits_2 = int(rank <= 2)
    return 1.0 / rank, hits_1, hits_2

def k_fold_corss_val(data_path, rule_path):
    """Evaluate MYCIN with 10-fold cross-validation and print the metrics.

    For each fold ``i`` (0-9), a fresh ``Shell`` is configured with the
    symptom/disease vocabularies and the rules of fold ``i``
    (``load_rules(..., K_fold=i)``).  Each sample's symptoms are installed
    as the initial state, the shell is executed, and the ground-truth disease
    is scored with ``get_MRR``.  Per-fold metrics are printed, followed by
    their averages over the 10 folds.

    Args:
        data_path: Directory prefix of the diagnosis data files.
        rule_path: Directory prefix of the rule files.
    """
    acc_list, valid_acc_list = [], []
    hits_1_list, hits_2_list = [], []
    mrr_list = []

    for i in range(10):
        # NOTE(review): load_data() in this module has the signature
        # (data_path, split, IncreDise=None).  It accepts neither ``K_fold``
        # nor ``change`` and requires ``split``, so this call raises TypeError.
        # The per-fold data layout is not visible here -- TODO fix this call
        # (or extend load_data) once it is known.
        data, symptoms, diseases = load_data(data_path, K_fold=i, change=None)
        print(f'\nTest Fold-{i} with {len(data)} Samples ...')
        # Get MYCIN Shell
        sh = Shell()
        define_contexts(sh)
        define_params(sh, symptoms, diseases)
        load_rules(sh, rule_path, K_fold=i)

        # Test MYCIN
        correct_cnt = 0  # a correct diagnosis is exactly a Hits@1
        hits_2_cnt = 0
        na_cnt = 0
        MRR = []
        for item in tqdm(data):
            sh.set_init_state(item['symptoms'])
            findings = sh.execute(['symptom', 'disease'])

            if get_most_likely_disease(findings) is None:
                na_cnt += 1

            # Certainties of the first context's 'identity' conclusions.
            goals_dict = {}
            for result in findings.values():
                for dise, certainty in result['identity'].items():
                    goals_dict[dise] = certainty
                break

            mrr, hits_1, hits_2 = get_MRR(goals_dict, item["disease"], diseases)
            MRR.append(mrr)
            correct_cnt += hits_1
            hits_2_cnt += hits_2

        n_samples = len(data)
        n_valid = n_samples - na_cnt
        global_acc = correct_cnt / n_samples
        # Guard against a fold in which no sample fired any rule.
        local_acc = correct_cnt / n_valid if n_valid != 0 else 0
        fold_mrr = np.mean(MRR)

        print(f'{correct_cnt} of {n_samples} samples can be correctly diagnosed by MYCIN. Global Accuracy:{global_acc}')
        print(f'{na_cnt} of {n_samples} samples cannot be diagnosed (no applicable rule) by MYCIN. N/A Rate:{na_cnt/n_samples}')
        print(f'{correct_cnt} of {n_valid} valid samples can be correctly diagnosed by MYCIN. Local Accuracy:{local_acc}')
        print(f'Global MRR: {fold_mrr}')
        print(f'Global Hits@1: {correct_cnt/n_samples}')
        print(f'Global Hits@2: {hits_2_cnt/n_samples}')

        acc_list.append(global_acc)
        hits_1_list.append(correct_cnt / n_samples)
        hits_2_list.append(hits_2_cnt / n_samples)
        mrr_list.append(fold_mrr)
        valid_acc_list.append(local_acc)

    avg_Acc = np.mean(acc_list)
    avg_Valid_Acc = np.mean(valid_acc_list)
    avg_MRR = np.mean(mrr_list)
    avg_Hits_1 = np.mean(hits_1_list)
    avg_Hits_2 = np.mean(hits_2_list)
    print(f"\n10-Fold avg Global Acc:{avg_Acc}")
    print(f"10-Fold avg Local Acc:{avg_Valid_Acc}")
    print(f"10-Fold avg Global MRR:{avg_MRR}")
    print(f"10-Fold avg Global Hits@1:{avg_Hits_1}")
    print(f"10-Fold avg Global Hits@2:{avg_Hits_2}")


def report_metrics(shell, data, diseases):
    """Run the shell on every sample, print the metrics, and return them.

    For each sample, ``item['symptoms']`` is installed as the shell's initial
    state and the shell is executed on the 'symptom' and 'disease' contexts.
    The ground truth ``item['disease']`` is then ranked with ``get_MRR``.

    Metrics:
        coverage: Fraction of samples for which ``get_most_likely_disease``
            found a disease (some certainty > 0).
        accuracy: Hits@1 among covered samples (0 if none is covered).
        accuracy_plus: Accuracy over all samples, where each uncovered sample
            earns chance credit ``1 / len(diseases)`` (0.25 for four diseases).
        f1_score: Harmonic mean of accuracy and coverage.
        hits_1_score, hits_2_score, mrr: Ranking metrics over all samples.

    Returns:
        ``(coverage, accuracy, accuracy_plus, f1_score, hits_1_score,
        hits_2_score, mrr)``.

    Raises:
        ValueError: if ``data`` is empty (every metric would divide by zero).
    """
    if not data:
        raise ValueError('report_metrics requires at least one sample')

    correct_cnt = 0  # a correct diagnosis is exactly a Hits@1
    hits_2_cnt = 0
    na_cnt = 0
    MRR = []
    for item in tqdm(data):
        shell.set_init_state(item['symptoms'])
        findings = shell.execute(['symptom', 'disease'])

        if get_most_likely_disease(findings) is None:
            na_cnt += 1

        # Certainties of the first context's 'identity' conclusions.
        goals_dict = {}
        for result in findings.values():
            for dise, certainty in result['identity'].items():
                goals_dict[dise] = certainty
            break

        mrr, hits_1, hits_2 = get_MRR(goals_dict, item["disease"], diseases)
        MRR.append(mrr)
        correct_cnt += hits_1
        hits_2_cnt += hits_2

    n_samples = len(data)
    n_valid = n_samples - na_cnt
    # Uncovered samples are credited with a uniform random guess.
    chance = 1.0 / len(diseases) if diseases else 0.0

    coverage = 1 - na_cnt / n_samples
    accuracy = correct_cnt / n_valid if n_valid != 0 else 0
    accuracy_plus = (correct_cnt + na_cnt * chance) / n_samples
    f1_score = 2 * accuracy * coverage / (accuracy + coverage) if (accuracy + coverage) != 0 else 0
    hits_1_score = correct_cnt / n_samples
    hits_2_score = hits_2_cnt / n_samples
    mrr_score = np.mean(MRR)

    print(f'Coverage:{coverage}; ({n_valid} of {n_samples} samples)')
    print(f'Accuracy:{accuracy}; ({correct_cnt} of {n_valid} samples)')
    print(f'Acc_plus: {accuracy_plus}')
    print(f'F1 Score: {f1_score}')
    print(f'Hits@1: {hits_1_score}')
    print(f'Hits@2: {hits_2_score}')
    print(f'MRR: {mrr_score}')

    return coverage, accuracy, accuracy_plus, f1_score, hits_1_score, hits_2_score, mrr_score


def main():
    """Tune the certainty cutoff on the validation split and report test metrics.

    Builds one MYCIN shell from the symptom/disease vocabularies and all rules
    under ``rule_path``.  It then picks the cutoff in {0, 0.1, ..., 1.0} with
    the best validation F1, evaluates the test split with that cutoff, and
    writes the test metrics to ``../../PERFORMANCE_MYCIN.json`` (relative to
    the current working directory).

    The CogKG root defaults to the original author's checkout.  It can be
    overridden with the ``COGKG_PATH`` environment variable, which must end
    with a path separator because paths are built by concatenation.
    """
    logging.basicConfig(level=logging.INFO)
    CogKG_path = os.environ.get('COGKG_PATH', '/home/weizhepei/workspace/CogKG/')
    data_path = CogKG_path + 'data/diagnose/aligned/'
    rule_path = CogKG_path + 'data/rule/disease_rule/'

    _, symptoms, diseases = load_data(data_path, split='train')
    # Get MYCIN Shell
    sh = Shell()
    define_contexts(sh)
    define_params(sh, symptoms, diseases)
    load_rules(sh, rule_path, K_fold=None)

    valid_data, _, _ = load_data(data_path, split='valid')
    print(f'\nPerformance on Valid set with {len(valid_data)} Samples ...')
    BEST_CUTOFF = 0
    best_f1 = 0
    for cutoff in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        sh.set_cutoff(cutoff)
        f1_score = report_metrics(sh, valid_data, diseases)[3]
        # Strict '>' keeps the smallest cutoff among equally good ones.
        if f1_score > best_f1:
            best_f1 = f1_score
            BEST_CUTOFF = cutoff

    test_data, _, _ = load_data(data_path, split='test')
    print(f'\nPerformance on Test set with {len(test_data)} Samples, Cutoff {BEST_CUTOFF} ...')
    sh.set_cutoff(BEST_CUTOFF)
    coverage, accuracy, accuracy_plus, f1_score, hits_1_score, hits_2_score, mrr = report_metrics(sh, test_data, diseases)
    PERFORMANCE = {'Coverage':coverage, 'Accuracy':accuracy, 'Accuracy_plus':accuracy_plus, 'F1_score':f1_score, 'Hits@1':hits_1_score, 'Hits@2':hits_2_score, 'MRR':mrr}

    with open('../../PERFORMANCE_MYCIN.json', 'w', encoding='utf-8') as f:
        f.write(json.dumps(PERFORMANCE, ensure_ascii=False, indent=4))

# Run the validation/test evaluation only when executed as a script.
if __name__ == '__main__':
    main()