import itertools
import json

import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm

def _read_json(path):
    """Return the single JSON document stored in the UTF-8 file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _encode_symptoms(symptoms, symptom2id, feature_type):
    """Turn one record's ``{symptom: 'True'|'False'}`` mapping into a feature list.

    Raises:
        ValueError: if a symptom value is neither ``'True'`` nor ``'False'``.
        KeyError: if a symptom name is missing from *symptom2id*.
    """
    scalar = feature_type == 'scalar'
    if scalar:
        feature = [-1] * len(symptom2id)
    else:
        # One fresh list per slot: ``[[0, 0]] * n`` would alias a single inner
        # list across every unobserved symptom.
        feature = [[0, 0] for _ in range(len(symptom2id))]

    for name, value in symptoms.items():
        if value == 'True':
            encoded = 1 if scalar else [1, 0]
        elif value == 'False':
            encoded = 0 if scalar else [0, 1]
        else:
            raise ValueError(
                f"symptom {name!r} has value {value!r}; expected 'True' or 'False'"
            )
        feature[symptom2id[name]] = encoded
    return feature


def _load_split(path, symptom2id, feature_type):
    """Read one JSON-lines split file and return ``(features, labels)``."""
    features, labels = [], []
    with open(path, 'r', encoding='utf-8') as f:
        # Stream the file line by line instead of materializing it.
        for line_no, raw in enumerate(f, start=1):
            if not raw.strip():
                # Tolerate blank / trailing empty lines.
                continue
            record = json.loads(raw)
            try:
                feature = _encode_symptoms(record['symptoms'], symptom2id, feature_type)
            except ValueError as err:
                raise ValueError(f'{path}, line {line_no}: {err}') from err
            features.append(feature)
            # The label is kept as the raw value of the 'disease' field
            # (it is not mapped through disease2id).
            labels.append(record['disease'])
    return features, labels


def load_data(data_path, feature_type='scalar'):
    """Load the id mappings and the train/valid/test splits found under *data_path*.

    File names are appended directly to *data_path* (plain string
    concatenation), so a directory must be given with its trailing separator.

    The four mapping files (``symptom2id.json``, ``disease2id.json``,
    ``id2symptom.json``, ``id2disease.json``) each hold one JSON document.
    The split files (``diagnose_train.json``, ``diagnose_test.json``,
    ``diagnose_valid.json``) hold one JSON object per line with a ``'disease'``
    field and a ``'symptoms'`` mapping whose values are the strings ``'True'``
    or ``'False'``. Blank lines are skipped.

    Feature encoding, one slot per entry of ``symptom2id``:

    * ``feature_type == 'scalar'``: ``1`` for ``'True'``, ``0`` for ``'False'``,
      ``-1`` for symptoms absent from the record.
    * any other ``feature_type``: ``[1, 0]`` for ``'True'``, ``[0, 1]`` for
      ``'False'``, ``[0, 0]`` for symptoms absent from the record.

    Args:
        data_path: Prefix that the file names are appended to.
        feature_type: ``'scalar'`` or anything else for the two-element encoding.

    Returns:
        A 10-tuple ``(features_train, labels_train, features_valid,
        labels_valid, features_test, labels_test, symptom2id, id2symptom,
        disease2id, id2disease)``.

    Raises:
        ValueError: if a symptom value is neither ``'True'`` nor ``'False'``;
            the message names the file, line, symptom and offending value.
        KeyError: if a record mentions a symptom missing from ``symptom2id``.
    """
    symptom2id = _read_json(data_path + 'symptom2id.json')
    disease2id = _read_json(data_path + 'disease2id.json')
    id2symptom = _read_json(data_path + 'id2symptom.json')
    id2disease = _read_json(data_path + 'id2disease.json')

    features_train, labels_train = _load_split(
        data_path + 'diagnose_train.json', symptom2id, feature_type)
    features_test, labels_test = _load_split(
        data_path + 'diagnose_test.json', symptom2id, feature_type)
    features_valid, labels_valid = _load_split(
        data_path + 'diagnose_valid.json', symptom2id, feature_type)

    return (features_train, labels_train, features_valid, labels_valid,
            features_test, labels_test, symptom2id, id2symptom, disease2id,
            id2disease)

def get_rank_score(clf, X, y):
    """Score *clf* on ``(X, y)`` with ranking metrics.

    For every sample the true label's rank among the predicted class
    probabilities is computed pessimistically: only classes with a strictly
    lower probability are placed below it, so ties resolve to the worst rank.

    Args:
        clf: Fitted classifier exposing ``predict_proba`` and ``classes_``.
        X: Samples to score.
        y: True labels, indexable by position and as long as the predictions.

    Returns:
        ``(mrr, hits_at_1, hits_at_2)``: the mean reciprocal rank and the
        fractions of samples whose true label is ranked first / within the
        first two.
    """
    probabilities = clf.predict_proba(X)
    assert len(probabilities) == len(y)

    reciprocal_ranks = []
    top1_total = 0
    top2_total = 0
    for row_idx, row in enumerate(probabilities):
        # Map each class label to the probability predicted for this sample.
        prob_of = {clf.classes_[col]: p for col, p in enumerate(row)}
        gold = y[row_idx]
        gold_prob = prob_of[gold]

        strictly_lower = sum(
            1 for label, p in prob_of.items() if label != gold and p < gold_prob
        )
        # Everything not strictly below the gold label counts against it.
        rank = len(prob_of) - strictly_lower
        assert 1 <= rank <= len(prob_of)

        reciprocal_ranks.append(1.0 / rank)
        top1_total += 1 if rank == 1 else 0
        top2_total += 1 if rank in (1, 2) else 0

    mrr = np.mean(reciprocal_ranks)
    return mrr, 1.0 * top1_total / len(y), 1.0 * top2_total / len(y)

def _as_arrays(features, labels):
    """Convert feature/label lists to numpy arrays, flattening each sample to one row."""
    X = np.asarray(features)
    y = np.asarray(labels)
    return X.reshape(X.shape[0], -1), y


def main(cogkg_path='/home/weizhepei/workspace/CogKG/',
         output_path='../PERFORMANCE_EBM.json'):
    """Grid-search an ExplainableBoostingClassifier and report test metrics.

    Every hyper-parameter combination is fitted on the training split and
    scored by Hits@1 on the validation split. The best combination (first one
    wins on ties) is refitted on the training split, evaluated on the test
    split, and the resulting metrics are printed and written as JSON.

    Args:
        cogkg_path: Project root prefix; the data is read from
            ``cogkg_path + 'data/diagnose/aligned/'``.
        output_path: File the final metrics dictionary is written to.
    """
    data_path = cogkg_path + 'data/diagnose/aligned/'
    (features_train, labels_train, features_valid, labels_valid,
     features_test, labels_test, _, _, _, id2dise) = load_data(data_path, feature_type='scalar')
    print(f'Train set size:{len(labels_train)}')
    print(f'Valid set size:{len(labels_valid)}')
    print(f'Test set size:{len(labels_test)}')
    print(f'Feature Dimension:{len(features_train[0])}')
    print(f'Label num:{len(id2dise)}')

    # Convert data to numpy array
    X_train, y_train = _as_arrays(features_train, labels_train)
    X_valid, y_valid = _as_arrays(features_valid, labels_valid)
    X_test, y_test = _as_arrays(features_test, labels_test)

    print(f'train: {X_train.shape}, {y_train.shape}')
    print(f'valid: {X_valid.shape}, {y_valid.shape}')
    print(f'test: {X_test.shape}, {y_test.shape}')

    MULTICLASS_INTERACTION = 0
    RANDOM_STATE = 2021
    WORKERS = 16
    fixed_params = dict(interactions=MULTICLASS_INTERACTION,
                        random_state=RANDOM_STATE, n_jobs=WORKERS)

    # Start below any reachable score so the first configuration is always
    # recorded; with a start of 0 and a strict '>' an all-zero validation run
    # would leave no winner and the final model would get placeholder values.
    best_acc = float('-inf')
    best_params = None
    for binning in tqdm(['uniform', 'quantile', 'quantile_humanized']):
        # product() iterates in the same order as the equivalent nested loops
        # (rightmost argument varies fastest).
        grid = itertools.product(
            [1e-4, 1e-3, 1e-2],  # learning_rate
            [2, 3, 4, 5],        # min_samples_leaf
            [2, 3, 4, 5],        # max_leaves
            range(11),           # inner_bags
            range(1, 11),        # outer_bags
        )
        for lr, min_samples_leaf, max_leaves, inner_bags, outer_bags in grid:
            params = dict(binning=binning, learning_rate=lr,
                          min_samples_leaf=min_samples_leaf,
                          max_leaves=max_leaves, inner_bags=inner_bags,
                          outer_bags=outer_bags)
            clf = ExplainableBoostingClassifier(**fixed_params, **params)
            clf.fit(X_train, y_train)
            _, hits_1, _ = get_rank_score(clf, X_valid, y_valid)
            if hits_1 > best_acc:
                best_acc = hits_1
                best_params = params

    print(f"best binning = {best_params['binning']}; best_lr={best_params['learning_rate']}; best_min_samples_leaf:{best_params['min_samples_leaf']}; best_max_leaves:{best_params['max_leaves']}; best_inner_bags:{best_params['inner_bags']}; best_outer_bags:{best_params['outer_bags']} validation acc:{best_acc}\n")

    # Refit with the winning configuration and evaluate on the held-out test split.
    clf = ExplainableBoostingClassifier(**fixed_params, **best_params)
    clf.fit(X_train, y_train)
    mrr, hits_1, hits_2 = get_rank_score(clf, X_test, y_test)

    y_pred = clf.predict(X_test)
    macro_f1 = f1_score(y_test, y_pred, average='macro')
    macro_p = precision_score(y_test, y_pred, average='macro')
    macro_r = recall_score(y_test, y_pred, average='macro')

    print(f'macro p: {macro_p}, r: {macro_r}, f1: {macro_f1}\n')

    acc_global = hits_1
    acc_local = hits_1
    # Harmonic mean of acc_local and the constant 1 (presumably the
    # 'Coverage' value reported below -- confirm with the metric's definition).
    psu_f1_score = 2*acc_local*1 / (1 + acc_local)
    print(f"avg acc:{acc_global}")
    print(f"f1_score:{psu_f1_score}")
    print(f"avg mrr:{mrr}")
    print(f"avg Hits@1:{hits_1}")
    print(f"avg Hits@2:{hits_2}")

    PERFORMANCE = {'Acc_global':acc_global, 'Acc_local':acc_local, 'F1_score':psu_f1_score, 'Hits@1':hits_1, 'Hits@2':hits_2, 'MRR':mrr, 'Coverage':1}

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(PERFORMANCE, ensure_ascii=False, indent=4))


# Run the grid search and evaluation only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()

